From 0fc0cd7ce754731fcf2fc858a72ea7081d4224f6 Mon Sep 17 00:00:00 2001 From: Ville Laitila Date: Sat, 28 Mar 2026 21:46:18 +0200 Subject: [PATCH 1/4] =?UTF-8?q?Add=20SGraph=20Query=20Language=20(P1+P2)?= =?UTF-8?q?=20=E2=80=94=20architecture-native=20filtering=20and=20dependen?= =?UTF-8?q?cy=20queries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parser: recursive descent with 9-level priority (OR, AND, parens, ---, --->, -->, --, NOT, attribute filters, keyword/path fallback). Expressions: 14 AST node types covering element selection, attribute filtering, direct/transitive dependency search, shortest path, and boolean logic. Evaluator: ghost-parent view model pattern — result elements preserve correct paths without copying the full ancestor tree. Supports chain search (DFS, max depth 20, cycle prevention) and shortest path (undirected BFS). 78 tests covering parser token recognition, precedence, and integration tests against a realistic multi-module test model with dependency chains. --- src/sgraph/query/__init__.py | 11 + src/sgraph/query/engine.py | 50 +++ src/sgraph/query/evaluator.py | 714 ++++++++++++++++++++++++++++++ src/sgraph/query/expressions.py | 127 ++++++ src/sgraph/query/parser.py | 462 ++++++++++++++++++++ tests/test_query.py | 749 ++++++++++++++++++++++++++++++++ 6 files changed, 2113 insertions(+) create mode 100644 src/sgraph/query/__init__.py create mode 100644 src/sgraph/query/engine.py create mode 100644 src/sgraph/query/evaluator.py create mode 100644 src/sgraph/query/expressions.py create mode 100644 src/sgraph/query/parser.py create mode 100644 tests/test_query.py diff --git a/src/sgraph/query/__init__.py b/src/sgraph/query/__init__.py new file mode 100644 index 0000000..6e903bd --- /dev/null +++ b/src/sgraph/query/__init__.py @@ -0,0 +1,11 @@ +"""SGraph Query Language — filter and traverse SGraph models. 
+ +Public API:: + + from sgraph.query import query + + result = query(model, '@type=file AND @loc>500') +""" +from sgraph.query.engine import query + +__all__ = ['query'] diff --git a/src/sgraph/query/engine.py b/src/sgraph/query/engine.py new file mode 100644 index 0000000..d072f67 --- /dev/null +++ b/src/sgraph/query/engine.py @@ -0,0 +1,50 @@ +"""High-level entry point for the SGraph Query Language.""" +from __future__ import annotations + +from sgraph import SGraph +from sgraph.query.evaluator import evaluate +from sgraph.query.parser import parse + + +def query(model: SGraph, expression: str) -> SGraph: + """Execute an SGraph Query Language expression against a model. + + Parses *expression* into an AST and evaluates it against *model*, + returning a new SGraph containing only the matching elements and the + dependency edges connecting them. + + The original *model* is never mutated. + + Args: + model: The source model to query. + expression: An SGraph QL expression string, e.g.:: + + '@type=file AND @loc>500' + '"/src/web" --> "/src/db"' + '"/" AND NOT "/External"' + + Returns: + A new :class:`~sgraph.SGraph` with matching elements and their + connecting associations. + + Raises: + ValueError: If *expression* cannot be parsed. 
+ + Examples:: + + from sgraph import SGraph + from sgraph.query import query + + model = SGraph.parse_xml_or_zipped_xml('model.xml.zip') + + # All Python files with more than 500 lines + result = query(model, '@type=file AND @loc>500') + + # Dependencies from web module to db module + result = query(model, '"/src/web" --> "/src/db"') + + # Everything except external dependencies + result = query(model, '"/" AND NOT "/External"') + """ + ast = parse(expression) + return evaluate(ast, model, total_model=model) diff --git a/src/sgraph/query/evaluator.py b/src/sgraph/query/evaluator.py new file mode 100644 index 0000000..5a619e7 --- /dev/null +++ b/src/sgraph/query/evaluator.py @@ -0,0 +1,714 @@ +"""Evaluator — turns a parsed Expression into a filtered SGraph. + +## View Model Design + +All result SGraph instances use a consistent **ghost-parent** structure: + +1. ``result.rootNode.children`` contains ONLY the explicitly matched top-level + elements — no ancestor-only structural nodes. + +2. Each element in the result is a shallow copy: + - **Top-level elements**: ``parent`` is set to the ORIGINAL model's parent + (a "ghost" ancestor chain). This makes ``getPath()`` return the correct + full path, while ``traverseElements`` from ``result.rootNode`` does NOT + visit the ghost ancestors (they are not in any result node's children + list). + - **Nested elements** (children within a subtree copy): ``parent`` points to + their copy parent within the result subtree, providing correct paths for + all descendants. + +3. Associations in the result graph reference result elements (not originals). 
+ +## Operator semantics + +- AND: sequential filter — apply right to result of left +- OR: union — evaluate both on original model, merge top-level elements +- NOT: subtract — evaluate inner on total, prune those paths from current view +""" +from __future__ import annotations + +import re +from typing import Optional + +from sgraph import SElement, SElementAssociation, SGraph +from sgraph.query.expressions import ( + AndExpr, + AttrEqualsExpr, + AttrGtExpr, + AttrLtExpr, + AttrNotEqualsExpr, + AttrRegexExpr, + ChainSearchExpr, + DepSearchExpr, + Expression, + HasAttrExpr, + KeywordExpr, + NotExpr, + OrExpr, + ParenExpr, + ShortestPathExpr, +) + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def evaluate(expr: Expression, model: SGraph, total_model: Optional[SGraph] = None) -> SGraph: + """Evaluate *expr* against *model* and return a new filtered SGraph. + + Args: + expr: Parsed AST node. + model: Model to filter (accumulates through AND chains). + total_model: Original unfiltered model for NOT complement. Defaults to *model*. + + Returns: + A new :class:`~sgraph.SGraph` with only matching elements. 
+ """ + total = total_model if total_model is not None else model + + if isinstance(expr, (HasAttrExpr, AttrEqualsExpr, AttrNotEqualsExpr, + AttrGtExpr, AttrLtExpr, AttrRegexExpr)): + return _eval_predicate(expr, model) + if isinstance(expr, KeywordExpr): + return _eval_keyword(expr, model) + if isinstance(expr, AndExpr): + left = evaluate(expr.left, model, total) + return evaluate(expr.right, left, total) + if isinstance(expr, OrExpr): + left = evaluate(expr.left, model, total) + right = evaluate(expr.right, model, total) + return _union(left, right) + if isinstance(expr, NotExpr): + return _eval_not(expr, model, total) + if isinstance(expr, ParenExpr): + return evaluate(expr.inner, model, total) + if isinstance(expr, DepSearchExpr): + return _eval_dep_search(expr, model, total) + if isinstance(expr, ChainSearchExpr): + return _eval_chain_search(expr, model, total) + if isinstance(expr, ShortestPathExpr): + return _eval_shortest_path(expr, model, total) + + raise ValueError(f"Unknown expression type: {type(expr)}") + + +# --------------------------------------------------------------------------- +# Element copy helpers +# --------------------------------------------------------------------------- + +def _make_copy(src: SElement, parent: Optional[SElement]) -> SElement: + """Create a shallow copy of *src* with the given *parent*.""" + copy: SElement = object.__new__(SElement) + copy.name = src.name + copy.parent = parent + copy.children = [] + copy.childrenDict = {} + copy.outgoing = [] + copy.incoming = [] + copy._incoming_index = {} + copy.attrs = dict(src.attrs) + copy.human_readable_name = src.human_readable_name + return copy + + +def _subtree_copy(src: SElement, parent: Optional[SElement]) -> SElement: + """Recursively copy *src* and all its descendants. + + The returned copy has ``parent = parent`` and its children are copies + of *src*'s children (with correct parent pointers). 
+ """ + copy = _make_copy(src, parent) + for child in src.children: + child_copy = _subtree_copy(child, copy) + copy.children.append(child_copy) + copy.childrenDict[child_copy.name] = child_copy + return copy + + +def _new_result() -> SGraph: + return SGraph(SElement(None, '')) + + +def _all_paths(graph: SGraph) -> set[str]: + paths: set[str] = set() + graph.rootNode.traverseElements(lambda e: paths.add(e.getPath())) + return paths + + +def _all_elements(graph: SGraph) -> list[SElement]: + elems: list[SElement] = [] + + def visit(e: SElement) -> None: + if e.parent is not None: + elems.append(e) + + graph.rootNode.traverseElements(visit) + return elems + + +def _add_subtree(result: SGraph, src_elem: SElement) -> None: + """Add *src_elem* and its full subtree as a top-level entry in *result*. + + Uses *src_elem*'s original parent as a ghost for path resolution: + the copy's parent is set to ``src_elem.parent`` (not ``result.rootNode``), + so ``getPath()`` returns the full original path. The copy IS in + ``result.rootNode.children`` for traversal. + """ + top = _subtree_copy(src_elem, src_elem.parent) # ghost parent = original parent + result.rootNode.children.append(top) + + +def _add_flat(result: SGraph, src_elem: SElement) -> None: + """Add a single element (no children) to *result* as a flat ghost copy. + + The copy's parent is ``src_elem.parent`` (ghost), so ``getPath()`` works. + The copy's children list is empty — ``traverseElements`` does NOT recurse. + """ + copy = _make_copy(src_elem, src_elem.parent) + result.rootNode.children.append(copy) + + +def _is_flat_model(model: SGraph) -> bool: + """True if the model was produced by an attribute predicate (flat output). + + A flat model's top-level elements have ghost parents AND no children in + the result graph (they were added via ``_add_flat`` which creates empty + children lists). Subtree results ALSO use ghost parents but DO have + children in the result. 
+ """ + for child in model.rootNode.children: + if child.children: + return False # has children → not flat + # If ALL top-level elements have no children, treat as flat + # (even a subtree of leaf elements would work correctly as flat) + return True + + +# --------------------------------------------------------------------------- +# Attribute filters — flat results +# --------------------------------------------------------------------------- + +def _attr_value(elem: SElement, attr_name: str) -> Optional[str]: + if attr_name == 'name': + return elem.name + if attr_name == 'type': + t = elem.getType() + return t if t else None + raw = elem.attrs.get(attr_name) + if raw is None: + return None + return ';'.join(str(v) for v in raw) if isinstance(raw, list) else str(raw) + + +def _matches(expr: Expression, elem: SElement) -> bool: + if isinstance(expr, HasAttrExpr): + if expr.attr_name == 'name': + return True + if expr.attr_name == 'type': + return bool(elem.getType()) + return expr.attr_name in elem.attrs + + attr_name: str = getattr(expr, 'attr_name', '') + val = _attr_value(elem, attr_name) + + if val is None: + return isinstance(expr, AttrNotEqualsExpr) # absent → trivially "not equals" + + if isinstance(expr, AttrEqualsExpr): + return val == expr.value if expr.exact else expr.value.lower() in val.lower() + if isinstance(expr, AttrNotEqualsExpr): + return val != expr.value if expr.exact else expr.value.lower() not in val.lower() + if isinstance(expr, AttrGtExpr): + try: + return float(val) > expr.value + except (ValueError, TypeError): + return False + if isinstance(expr, AttrLtExpr): + try: + return float(val) < expr.value + except (ValueError, TypeError): + return False + if isinstance(expr, AttrRegexExpr): + return bool(re.search(expr.pattern, val)) + + raise ValueError(f"Not a predicate expression: {type(expr)}") + + +def _eval_predicate(expr: Expression, model: SGraph) -> SGraph: + """Return a flat result: one ghost copy per matched element, no 
ancestors.""" + result = _new_result() + seen: set[str] = set() + + def visit(elem: SElement) -> None: + if elem.parent is None: + return + if _matches(expr, elem): + path = elem.getPath() + if path not in seen: + seen.add(path) + _add_flat(result, elem) + + model.rootNode.traverseElements(visit) + return result + + +# --------------------------------------------------------------------------- +# Keyword / path — subtree results +# --------------------------------------------------------------------------- + +def _eval_keyword(expr: KeywordExpr, model: SGraph) -> SGraph: + result = _new_result() + kw = expr.keyword + seen: set[str] = set() + + if expr.exact: + if kw == '/': + for child in model.rootNode.children: + if child.getPath() not in seen: + seen.add(child.getPath()) + _add_subtree(result, child) + return result + + if kw.endswith('/**'): + base = model.findElementFromPath(kw[:-3]) + if base is not None: + for child in base.children: + if child.getPath() not in seen: + seen.add(child.getPath()) + _add_subtree(result, child) + return result + + if kw.endswith('/*'): + base = model.findElementFromPath(kw[:-2]) + if base is not None: + for child in base.children: + if child.getPath() not in seen: + seen.add(child.getPath()) + # Single-level: just the element (no children) + copy = _make_copy(child, child.parent) + result.rootNode.children.append(copy) + return result + + elem = model.findElementFromPath(kw) + if elem is not None: + _add_subtree(result, elem) + return result + + # Keyword search + ends_with = kw.endswith('$') + term = kw[:-1].lower() if ends_with else kw.lower() + + def visit(elem: SElement) -> None: + if elem.parent is None: + return + n = elem.name.lower() + if n.endswith(term) if ends_with else term in n: + path = elem.getPath() + if path not in seen: + seen.add(path) + _add_subtree(result, elem) + + model.rootNode.traverseElements(visit) + return result + + +# --------------------------------------------------------------------------- +# 
Logical operators +# --------------------------------------------------------------------------- + +def _union(a: SGraph, b: SGraph) -> SGraph: + """Merge two results, deduplicating by top-level path.""" + result = _new_result() + seen: set[str] = set() + a_flat = _is_flat_model(a) + b_flat = _is_flat_model(b) + + if a_flat or b_flat: + # Flat union: preserve ghost parents + for src in (a, b): + for elem in src.rootNode.children: + path = elem.getPath() + if path not in seen: + seen.add(path) + _add_flat(result, elem) + else: + # Hierarchical union: preserve subtrees with ghost parents + for src in (a, b): + for top_elem in src.rootNode.children: + path = top_elem.getPath() + if path not in seen: + seen.add(path) + _add_subtree(result, _find_original_top(top_elem)) + + return result + + +def _find_original_top(elem: SElement) -> SElement: + """Return the original-model element corresponding to this result element. + + For subtree results, the copy's ghost parent chain leads back to the + original parent. We reconstruct the original element by finding the + element whose name matches in the parent's childrenDict. 
+ """ + if elem.parent is None: + return elem + original_parent = elem.parent + if elem.name in original_parent.childrenDict: + return original_parent.childrenDict[elem.name] + # Fallback: elem itself (may be a copy) + return elem + + +def _eval_not(expr: NotExpr, model: SGraph, total: SGraph) -> SGraph: + """NOT: subtract inner-matched paths from current model view.""" + inner = evaluate(expr.inner, total, total) + # Collect excluded paths from the actual matched subtrees (not ancestors) + excluded: set[str] = set() + for top_elem in inner.rootNode.children: + top_elem.traverseElements(lambda e: excluded.add(e.getPath())) + + result = _new_result() + is_flat = _is_flat_model(model) + + if is_flat: + seen: set[str] = set() + for elem in model.rootNode.children: + path = elem.getPath() + if path not in excluded and path not in seen: + seen.add(path) + _add_flat(result, elem) + else: + for top_elem in model.rootNode.children: + _prune_into(top_elem, result.rootNode, excluded, + ghost_parent=top_elem.parent) + + return result + + +def _prune_into( + src: SElement, + dst_parent: SElement, + excluded: set[str], + ghost_parent: Optional[SElement] = None, +) -> None: + """Recursively copy *src* into *dst_parent*, skipping excluded paths.""" + if src.getPath() in excluded: + return + + copy = _make_copy(src, ghost_parent if ghost_parent is not None else dst_parent) + dst_parent.children.append(copy) + if copy.name not in dst_parent.childrenDict: + dst_parent.childrenDict[copy.name] = copy + + for child in src.children: + # Children use copy as parent (proper hierarchy within result subtree) + _prune_into(child, copy, excluded) + + +# --------------------------------------------------------------------------- +# Dependency search +# --------------------------------------------------------------------------- + +def _eval_dep_search(expr: DepSearchExpr, model: SGraph, total: SGraph) -> SGraph: + """Find direct dependencies FROM → TO in *total*.""" + from_originals = 
_resolve_endpoint(expr.from_expr, total) + to_originals = _resolve_endpoint(expr.to_expr, total) + + from_set = _descendants(from_originals) + to_set = _descendants(to_originals) + to_paths: set[str] = {e.getPath() for e in to_set} + + def assoc_ok(a: SElementAssociation) -> bool: + if a.toElement.getPath() not in to_paths: + return False + if expr.dep_type is not None and a.deptype != expr.dep_type: + return False + if expr.dep_attr_name is not None: + av = a.attrs.get(expr.dep_attr_name) + if av is None and expr.dep_attr_name == 'type': + av = a.deptype + if av is None: + return False + if expr.dep_attr_value is not None and str(av) != expr.dep_attr_value: + return False + return True + + matched: list[SElementAssociation] = [] + for fe in from_set: + for a in fe.outgoing: + if assoc_ok(a): + matched.append(a) + + if not expr.directed: + from_paths: set[str] = {e.getPath() for e in from_set} + + def rev_ok(a: SElementAssociation) -> bool: + if a.toElement.getPath() not in from_paths: + return False + if expr.dep_type is not None and a.deptype != expr.dep_type: + return False + if expr.dep_attr_name is not None: + av = a.attrs.get(expr.dep_attr_name) + if av is None and expr.dep_attr_name == 'type': + av = a.deptype + if av is None: + return False + if expr.dep_attr_value is not None and str(av) != expr.dep_attr_value: + return False + return True + + for te in to_set: + for a in te.outgoing: + if rev_ok(a): + matched.append(a) + + # Build result: shallow copies of matched endpoints + new associations + result = _new_result() + elem_map: dict[str, SElement] = {} + + def get_copy(original: SElement) -> SElement: + path = original.getPath() + if path in elem_map: + return elem_map[path] + copy = _make_copy(original, original.parent) # ghost parent for path + elem_map[path] = copy + result.rootNode.children.append(copy) + return copy + + for assoc in matched: + from_copy = get_copy(assoc.fromElement) + to_copy = get_copy(assoc.toElement) + key = (id(from_copy), 
assoc.deptype) + if key not in to_copy._incoming_index: + new_assoc = SElementAssociation(from_copy, to_copy, assoc.deptype, dict(assoc.attrs)) + new_assoc.initElems() + + return result + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _resolve_endpoint( + endpoint_expr: Optional[Expression], + total: SGraph, +) -> list[SElement]: + """Evaluate endpoint expression; return ORIGINAL model elements. + + Maps result elements back to originals so outgoing associations are intact. + None means wildcard — return all children of total.rootNode. + """ + if endpoint_expr is None: + return list(total.rootNode.children) + + filtered = evaluate(endpoint_expr, total, total) + result = [] + for elem in filtered.rootNode.children: + original = total.findElementFromPath(elem.getPath()) + if original is not None: + result.append(original) + else: + # Fallback: the element IS from the original (flat copy uses ghost parent) + # Reconstruct original via ghost parent chain + if elem.parent is not None and elem.name in elem.parent.childrenDict: + result.append(elem.parent.childrenDict[elem.name]) + return result + + +def _descendants(elements: list[SElement]) -> set[SElement]: + """Return *elements* and all descendants (from original model traversal).""" + result: set[SElement] = set() + for e in elements: + e.traverseElements(lambda x: result.add(x)) + return result + + +# --------------------------------------------------------------------------- +# Chain search ( ---> ) +# --------------------------------------------------------------------------- + +_CHAIN_MAX_DEPTH = 20 + + +def _eval_chain_search(expr: ChainSearchExpr, model: SGraph, total: SGraph) -> SGraph: + """Find all directed multi-hop chains FROM → ... 
→ TO via DFS.""" + from_originals = _resolve_endpoint(expr.from_expr, total) + to_originals = _resolve_endpoint(expr.to_expr, total) + is_wildcard_to = expr.to_expr is None + + from_set = _descendants(from_originals) + to_set = _descendants(to_originals) if not is_wildcard_to else set() + to_paths: set[str] = {e.getPath() for e in to_set} if not is_wildcard_to else set() + + def edge_ok(a: SElementAssociation) -> bool: + if expr.dep_type is not None and a.deptype != expr.dep_type: + return False + if expr.dep_attr_name is not None: + av = a.attrs.get(expr.dep_attr_name) + if av is None and expr.dep_attr_name == 'type': + av = a.deptype + if av is None: + return False + if expr.dep_attr_value is not None and str(av) != expr.dep_attr_value: + return False + return True + + # DFS to find all chains. Collect all associations on any chain. + chain_assocs: list[SElementAssociation] = [] + found_assoc_ids: set[int] = set() + + def is_target(path: str) -> bool: + if is_wildcard_to: + return path not in {e.getPath() for e in from_set} + return path in to_paths + + def dfs(elem: SElement, visited: set[str], chain: list[SElementAssociation], depth: int) -> None: + if depth > _CHAIN_MAX_DEPTH: + return + for a in elem.outgoing: + if not edge_ok(a): + continue + target = a.toElement + target_path = target.getPath() + if target_path in visited: + continue # cycle prevention + + if is_target(target_path): + # Found a chain — record all associations on this path + for ca in chain: + aid = id(ca) + if aid not in found_assoc_ids: + found_assoc_ids.add(aid) + chain_assocs.append(ca) + aid = id(a) + if aid not in found_assoc_ids: + found_assoc_ids.add(aid) + chain_assocs.append(a) + # Continue DFS through this node too (more chains possible) + new_visited = visited | {target_path} + dfs(target, new_visited, chain + [a], depth + 1) + else: + # Continue DFS + new_visited = visited | {target_path} + dfs(target, new_visited, chain + [a], depth + 1) + + for fe in from_set: + dfs(fe, 
{fe.getPath()}, [], 0) + + # Build result graph from collected associations + result = _new_result() + elem_map: dict[str, SElement] = {} + + def get_copy(original: SElement) -> SElement: + path = original.getPath() + if path in elem_map: + return elem_map[path] + copy = _make_copy(original, original.parent) + elem_map[path] = copy + result.rootNode.children.append(copy) + return copy + + for assoc in chain_assocs: + from_copy = get_copy(assoc.fromElement) + to_copy = get_copy(assoc.toElement) + new_assoc = SElementAssociation(from_copy, to_copy, assoc.deptype, dict(assoc.attrs)) + new_assoc.initElems() + + return result + + +# --------------------------------------------------------------------------- +# Shortest path ( --- ) +# --------------------------------------------------------------------------- + +def _eval_shortest_path(expr: ShortestPathExpr, model: SGraph, total: SGraph) -> SGraph: + """Find shortest undirected path FROM → ... → TO via BFS.""" + from_originals = _resolve_endpoint(expr.from_expr, total) + to_originals = _resolve_endpoint(expr.to_expr, total) + + if not from_originals or not to_originals: + return _new_result() + + from_set = _descendants(from_originals) + to_paths: set[str] = {e.getPath() for e in _descendants(to_originals)} + + # BFS from each from-element, treating graph as undirected + # Returns the first (shortest) path found as a list of (element, association) pairs + from collections import deque + + best_path: Optional[list[tuple[SElement, Optional[SElementAssociation]]]] = None + + for start in from_set: + start_path = start.getPath() + if start_path in to_paths: + # Trivial case: start is already in to-set + best_path = [(start, None)] + break + + queue: deque[tuple[SElement, list[tuple[SElement, Optional[SElementAssociation]]]]] = deque() + queue.append((start, [(start, None)])) + visited: set[str] = {start_path} + + while queue: + current, path = queue.popleft() + if best_path is not None and len(path) >= len(best_path): + 
break # can't beat the best already found + + # Explore outgoing + for a in current.outgoing: + neighbor = a.toElement + npath = neighbor.getPath() + if npath in visited: + continue + new_path = path + [(neighbor, a)] + if npath in to_paths: + if best_path is None or len(new_path) < len(best_path): + best_path = new_path + break + visited.add(npath) + queue.append((neighbor, new_path)) + + # Explore incoming (undirected) + for a in current.incoming: + neighbor = a.fromElement + npath = neighbor.getPath() + if npath in visited: + continue + new_path = path + [(neighbor, a)] + if npath in to_paths: + if best_path is None or len(new_path) < len(best_path): + best_path = new_path + break + visited.add(npath) + queue.append((neighbor, new_path)) + + if best_path is not None and len(best_path) <= 2: + break # optimal: direct neighbor + + if best_path is None: + return _new_result() + + # Build result graph from the path + result = _new_result() + elem_map: dict[str, SElement] = {} + + def get_copy(original: SElement) -> SElement: + path = original.getPath() + if path in elem_map: + return elem_map[path] + copy = _make_copy(original, original.parent) + elem_map[path] = copy + result.rootNode.children.append(copy) + return copy + + for elem, assoc in best_path: + get_copy(elem) + if assoc is not None: + from_copy = get_copy(assoc.fromElement) + to_copy = get_copy(assoc.toElement) + new_assoc = SElementAssociation(from_copy, to_copy, assoc.deptype, dict(assoc.attrs)) + new_assoc.initElems() + + return result diff --git a/src/sgraph/query/expressions.py b/src/sgraph/query/expressions.py new file mode 100644 index 0000000..6dbdd34 --- /dev/null +++ b/src/sgraph/query/expressions.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class KeywordExpr: + """Case-insensitive partial name match, or exact quoted path lookup.""" + keyword: str + exact: bool # True if the keyword was a quoted path 
like "/path/to/elem" + + +@dataclass +class HasAttrExpr: + """@attr — element has the attribute (any value).""" + attr_name: str + + +@dataclass +class AttrEqualsExpr: + """@attr=value or @attr="exact" — attribute equals/contains.""" + attr_name: str + value: str + exact: bool # True if value was quoted → exact match; False → case-insensitive contains + + +@dataclass +class AttrNotEqualsExpr: + """@attr!=value — attribute does not equal/contain value.""" + attr_name: str + value: str + exact: bool + + +@dataclass +class AttrGtExpr: + """@attr>number — numeric greater-than comparison.""" + attr_name: str + value: float + + +@dataclass +class AttrLtExpr: + """@attr TO or FROM -- TO — direct dependency search. + + from_expr=None means wildcard (any element as source). + to_expr=None means wildcard (any element as target). + dep_type is shorthand for dep_attr_value when dep_attr_name='type'. + """ + from_expr: Optional[Expression] # None = wildcard * + to_expr: Optional[Expression] # None = wildcard * + directed: bool # True for -->, False for -- + dep_type: Optional[str] = None # shorthand: -deptype-> + dep_attr_name: Optional[str] = None # -@attr-> attribute filter on the edge + dep_attr_value: Optional[str] = None # -@attr=value-> value filter on the edge + + +@dataclass +class ChainSearchExpr: + """FROM ---> TO — find all directed multi-hop chains (transitive paths). + + DFS traversal with cycle detection and max depth limit. + """ + from_expr: Optional[Expression] # None = wildcard * + to_expr: Optional[Expression] # None = wildcard * + dep_type: Optional[str] = None + dep_attr_name: Optional[str] = None + dep_attr_value: Optional[str] = None + + +@dataclass +class ShortestPathExpr: + """FROM --- TO — find shortest undirected path between two elements. + + BFS ignoring edge direction. 
+ """ + from_expr: Optional[Expression] # None = wildcard * + to_expr: Optional[Expression] # None = wildcard * + + +# Union type for all expression nodes — used as type hints throughout. +Expression = ( + KeywordExpr | HasAttrExpr | AttrEqualsExpr | AttrNotEqualsExpr | AttrGtExpr | AttrLtExpr + | AttrRegexExpr | AndExpr | OrExpr | NotExpr | ParenExpr | DepSearchExpr + | ChainSearchExpr | ShortestPathExpr +) diff --git a/src/sgraph/query/parser.py b/src/sgraph/query/parser.py new file mode 100644 index 0000000..29fa14c --- /dev/null +++ b/src/sgraph/query/parser.py @@ -0,0 +1,462 @@ +"""Recursive descent parser for the SGraph Query Language. + +Priority order (first match wins): + 1. OR — tried first so AND binds tighter (standard math precedence) + 2. AND — sequential filter, binds tighter than OR + 3. Parentheses + 4. Shortest Path (--- undirected BFS) + 5. Chain Search (---> all directed multi-hop paths) + 6. Dep Search (-->, --, and typed variants -type-> etc.) + 7. NOT + 8. Attribute filters (@attr=~"…", @attr=…, @attr!=…, @attr>…, @attr<…, @attr) + 9. Keyword / Exact Path (fallback) + +Both OR and AND splitting respect parentheses: the operator must appear +at nesting depth 0 (not inside any parenthesised group). This allows +expressions like ``(@type=file OR @type=dir) AND @loc>100`` to parse +correctly as AND(Paren(OR(…)), GT(…)). +""" +from __future__ import annotations + +import re +from typing import Optional + +from sgraph.query.expressions import ( + AndExpr, + AttrEqualsExpr, + AttrGtExpr, + AttrLtExpr, + AttrNotEqualsExpr, + AttrRegexExpr, + ChainSearchExpr, + DepSearchExpr, + Expression, + HasAttrExpr, + KeywordExpr, + NotExpr, + OrExpr, + ParenExpr, + ShortestPathExpr, +) + + +def parse(expression: str) -> Expression: + """Parse an SGraph Query Language string into an AST. + + Args: + expression: Raw query string, e.g. ``'@type=file AND @loc>500'``. + + Returns: + The root AST node representing the full expression. 
+ + Raises: + ValueError: If the expression cannot be parsed. + """ + s = expression.strip() + if not s: + raise ValueError("Empty query expression") + return _parse(s) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _parse(s: str) -> Expression: + """Core dispatch — tries each rule in priority order.""" + s = s.strip() + + result = (_try_or(s) or _try_and(s) or _try_paren(s) + or _try_shortest_path(s) or _try_chain_search(s) + or _try_dep_search(s) + or _try_not(s) or _try_attr(s) or _try_keyword(s)) + + if result is None: + raise ValueError(f"Cannot parse expression: {s!r}") + return result + + +def _find_top_level_operator(s: str, op: str) -> int: + """Find the index of *op* at parenthesis depth 0, or -1 if not found. + + Scans left-to-right, tracking paren depth and skipping quoted regions. + Returns the start index of the first match at depth 0. + """ + depth = 0 + in_quote = False + quote_char = '' + i = 0 + while i < len(s): + ch = s[i] + if in_quote: + if ch == quote_char: + in_quote = False + i += 1 + continue + if ch in ('"', "'"): + in_quote = True + quote_char = ch + i += 1 + continue + if ch == '(': + depth += 1 + elif ch == ')': + depth -= 1 + elif depth == 0 and s[i:i + len(op)] == op: + return i + i += 1 + return -1 + + +# --------------------------------------------------------------------------- +# 1. 
OR (tried first so AND binds tighter — standard precedence) +# --------------------------------------------------------------------------- + +def _try_or(s: str) -> Optional[Expression]: + """Split on first top-level ` OR ` (spaces required).""" + idx = _find_top_level_operator(s, ' OR ') + if idx == -1: + return None + left_src = s[:idx].strip() + right_src = s[idx + 4:].strip() + return OrExpr(left=_parse(left_src), right=_parse(right_src)) + + +# --------------------------------------------------------------------------- +# 2. AND +# --------------------------------------------------------------------------- + +def _try_and(s: str) -> Optional[Expression]: + """Split on first top-level ` AND ` (spaces required).""" + idx = _find_top_level_operator(s, ' AND ') + if idx == -1: + return None + left_src = s[:idx].strip() + right_src = s[idx + 5:].strip() + return AndExpr(left=_parse(left_src), right=_parse(right_src)) + + +# --------------------------------------------------------------------------- +# 3. Parentheses +# --------------------------------------------------------------------------- + +def _try_paren(s: str) -> Optional[Expression]: + """Match (…) wrapping the ENTIRE expression.""" + if not (s.startswith('(') and s.endswith(')')): + return None + # Verify the opening paren closes at the very end, not earlier. + depth = 0 + for i, ch in enumerate(s): + if ch == '(': + depth += 1 + elif ch == ')': + depth -= 1 + if depth == 0 and i < len(s) - 1: + # Closing paren found before the end — not a simple wrapper + return None + inner = s[1:-1].strip() + return ParenExpr(inner=_parse(inner)) + + +# --------------------------------------------------------------------------- +# 4. Shortest Path ( --- ) +# --------------------------------------------------------------------------- + +def _try_shortest_path(s: str) -> Optional[Expression]: + """Parse ``FROM --- TO`` (undirected shortest path, BFS). + + Must appear at depth 0, surrounded by spaces: `` --- ``. 
+ Tried before chain search and dep search to avoid ambiguity. + """ + idx = _find_top_level_operator(s, ' --- ') + if idx == -1: + return None + # Verify this is exactly --- not ---- or ---> + after = idx + 5 # position after ' --- ' + before = idx # position of ' ' + # Check the character before the operator is not '-' (would be ----) + if before > 0 and s[before - 1] == '-': + return None + # Check the character after is not '-' or '>' (would be ---- or --->) + if after < len(s) and s[after] in ('-', '>'): + return None + + from_src = s[:idx].strip() + to_src = s[after:].strip() + return ShortestPathExpr( + from_expr=_parse_dep_endpoint(from_src), + to_expr=_parse_dep_endpoint(to_src), + ) + + +# --------------------------------------------------------------------------- +# 5. Chain Search ( ---> and --type-> ) +# --------------------------------------------------------------------------- + +# Matches: " ---> " or " --label-> " (label between the -- and ->) +_CHAIN_PATTERN = re.compile( + r' (' + r'--->' # plain chain + r'|--[^-\s>][^>]*->' # --label-> (label between -- and ->) + r') ' +) + + +def _try_chain_search(s: str) -> Optional[Expression]: + """Parse ``FROM ---> TO`` and ``FROM --type-> TO`` (chain search, all paths).""" + in_quote = False + quote_char = '' + depth = 0 + i = 0 + op_match: Optional[re.Match[str]] = None + while i < len(s): + ch = s[i] + if in_quote: + if ch == quote_char: + in_quote = False + i += 1 + continue + if ch in ('"', "'"): + in_quote = True + quote_char = ch + i += 1 + continue + if ch == '(': + depth += 1 + elif ch == ')': + depth -= 1 + if depth == 0: + m = _CHAIN_PATTERN.match(s, i) + if m: + op_match = m + break + i += 1 + + if op_match is None: + return None + + op_text = op_match.group(1) + from_src = s[:op_match.start()].strip() + to_src = s[op_match.end():].strip() + + from_expr = _parse_dep_endpoint(from_src) + to_expr = _parse_dep_endpoint(to_src) + + dep_type: Optional[str] = None + dep_attr_name: Optional[str] = 
None + dep_attr_value: Optional[str] = None + + # Extract label: strip leading --, trailing -> + inner = op_text[2:] # remove leading -- + if inner.endswith('>'): + inner = inner[:-1] + inner = inner.rstrip('-') + + if inner: + if inner.startswith('@'): + attr_part = inner[1:] + if '=' in attr_part: + dep_attr_name, dep_attr_value = attr_part.split('=', 1) + else: + dep_attr_name = attr_part + else: + dep_type = inner + + return ChainSearchExpr( + from_expr=from_expr, + to_expr=to_expr, + dep_type=dep_type, + dep_attr_name=dep_attr_name, + dep_attr_value=dep_attr_value, + ) + + +# --------------------------------------------------------------------------- +# 6. Dependency search (--> -- -type-> -@attr-> -@attr=val->) +# --------------------------------------------------------------------------- + +# Pattern matches one of: +# " --> " plain directed +# " -label-> " directed with label (type or @attr) +# " -- " plain undirected (not --- or --->) +# " -label- " undirected with label +# Operator must be surrounded by spaces (spec requirement). +_DEP_PATTERN = re.compile( + r' (' + r'-->' # plain directed + r'|-[^-\s>][^>]*->' # -label-> (label cannot start with -) + r'|--(?![-\->])' # plain undirected: -- not followed by - or > + r'|-[^-\s>][^>]*-(?!>)' # -label- (undirected, not followed by >) + r') ' +) + + +def _try_dep_search(s: str) -> Optional[Expression]: + """Parse FROM --> TO, FROM -- TO and typed/attribute variants.""" + # Walk the string to find the operator, skipping quoted and parenthesised regions. 
+ in_quote = False + quote_char = '' + depth = 0 + i = 0 + op_match: Optional[re.Match[str]] = None + while i < len(s): + ch = s[i] + if in_quote: + if ch == quote_char: + in_quote = False + i += 1 + continue + if ch in ('"', "'"): + in_quote = True + quote_char = ch + i += 1 + continue + if ch == '(': + depth += 1 + elif ch == ')': + depth -= 1 + if depth == 0: + m = _DEP_PATTERN.match(s, i) + if m: + op_match = m + break + i += 1 + + if op_match is None: + return None + + op_text = op_match.group(1) # the operator token without surrounding spaces + from_src = s[:op_match.start()].strip() + to_src = s[op_match.end():].strip() + + from_expr = _parse_dep_endpoint(from_src) + to_expr = _parse_dep_endpoint(to_src) + + directed = op_text.endswith('>') + dep_type: Optional[str] = None + dep_attr_name: Optional[str] = None + dep_attr_value: Optional[str] = None + + # Extract the label between the dashes + inner = op_text.lstrip('-') + if inner.endswith('>'): + inner = inner[:-1] + inner = inner.rstrip('-') + + if inner: + if inner.startswith('@'): + attr_part = inner[1:] + if '=' in attr_part: + dep_attr_name, dep_attr_value = attr_part.split('=', 1) + else: + dep_attr_name = attr_part + else: + dep_type = inner + + return DepSearchExpr( + from_expr=from_expr, + to_expr=to_expr, + directed=directed, + dep_type=dep_type, + dep_attr_name=dep_attr_name, + dep_attr_value=dep_attr_value, + ) + + +def _parse_dep_endpoint(s: str) -> Optional[Expression]: + """Parse a dep-search endpoint. Returns None for wildcard ``*``.""" + if s in ('*', '"*"', ''): + return None + return _parse(s) + + +# --------------------------------------------------------------------------- +# 7. 
NOT +# --------------------------------------------------------------------------- + +def _try_not(s: str) -> Optional[Expression]: + """Match ``NOT <expr>``.""" + if not s.startswith('NOT '): + return None + inner = s[4:].strip() + return NotExpr(inner=_parse(inner)) + + +# --------------------------------------------------------------------------- +# 8. Attribute filters +# --------------------------------------------------------------------------- + +# @attr=~"pattern" +_RE_REGEX_MATCH = re.compile(r'^@([\w\-]+)=~"(.+)"$') + +# @attr="exact value" — quoted exact match +_RE_ATTR_EQ_QUOTED = re.compile(r'^@([\w\-]+)="([^"]*)"$') + +# @attr=value — unquoted, contains match (value cannot start with ~ or contain ") +_RE_ATTR_EQ_UNQUOTED = re.compile(r'^@([\w\-]+)=([^"~][^"]*)$') + +# @attr!="exact" or @attr!=value +_RE_ATTR_NEQ_QUOTED = re.compile(r'^@([\w\-]+)!="([^"]*)"$') +_RE_ATTR_NEQ_UNQUOTED = re.compile(r'^@([\w\-]+)!=([^"]+)$') + +# @attr>number +_RE_ATTR_GT = re.compile(r'^@([\w\-]+)>([\d.]+)$') + +# @attr<number +_RE_ATTR_LT = re.compile(r'^@([\w\-]+)<([\d.]+)$') + +# @attr — attribute presence check +_RE_HAS_ATTR = re.compile(r'^@([\w\-]+)$') + + +def _try_attr(s: str) -> Optional[Expression]: + """Try all attribute filter patterns in priority order.""" + # Regex match must be tried before equals (both start with @attr=) + m = _RE_REGEX_MATCH.match(s) + if m: + return AttrRegexExpr(attr_name=m.group(1), pattern=m.group(2)) + + # Not-equals (quoted then unquoted) + m = _RE_ATTR_NEQ_QUOTED.match(s) + if m: + return AttrNotEqualsExpr(attr_name=m.group(1), value=m.group(2), exact=True) + + m = _RE_ATTR_NEQ_UNQUOTED.match(s) + if m: + return AttrNotEqualsExpr(attr_name=m.group(1), value=m.group(2), exact=False) + + m = _RE_ATTR_GT.match(s) + if m: + return AttrGtExpr(attr_name=m.group(1), value=float(m.group(2))) + + m = _RE_ATTR_LT.match(s) + if m: + return AttrLtExpr(attr_name=m.group(1), value=float(m.group(2))) + + # Equals: quoted before unquoted + m = _RE_ATTR_EQ_QUOTED.match(s) + if m: + return AttrEqualsExpr(attr_name=m.group(1), value=m.group(2), exact=True) + + m = _RE_ATTR_EQ_UNQUOTED.match(s) + if m: + return 
AttrEqualsExpr(attr_name=m.group(1), value=m.group(2), exact=False) + + m = _RE_HAS_ATTR.match(s) + if m: + return HasAttrExpr(attr_name=m.group(1)) + + return None + + +# --------------------------------------------------------------------------- +# 9. Keyword / Exact Path fallback +# --------------------------------------------------------------------------- + +def _try_keyword(s: str) -> Optional[Expression]: + """Fallback: quoted string = exact path, unquoted = keyword search.""" + if s.startswith('"') and s.endswith('"') and len(s) >= 2: + path = s[1:-1] + return KeywordExpr(keyword=path, exact=True) + # Treat anything else as a keyword. + return KeywordExpr(keyword=s, exact=False) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..52d23f8 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,749 @@ +"""Tests for the SGraph Query Language (SGL) P1+P2 implementation. + +Covers: +- Parser: token recognition, precedence, all expression types +- Evaluator: integration tests against a realistic test model +- P2: chain search (--->) and shortest path (---) +""" +from __future__ import annotations + +import pytest + +from sgraph import SGraph, SElement, SElementAssociation +from sgraph.query import query +from sgraph.query.expressions import ( + AndExpr, + AttrEqualsExpr, + AttrGtExpr, + AttrLtExpr, + AttrNotEqualsExpr, + AttrRegexExpr, + ChainSearchExpr, + DepSearchExpr, + HasAttrExpr, + KeywordExpr, + NotExpr, + OrExpr, + ParenExpr, + ShortestPathExpr, +) +from sgraph.query.parser import parse + + +# --------------------------------------------------------------------------- +# Test model factory +# --------------------------------------------------------------------------- + +def create_test_model() -> SGraph: + """Build a small but representative model for query evaluation tests. 
+ + Structure: + /project (repository) + /project/src (dir) + /project/src/web (dir) + app.py (file, loc=500) + views.py (file, loc=200) + /project/src/db (dir) + models.py (file, loc=300) + queries.py (file, loc=150) + /project/src/common (dir) + utils.py (file, loc=50) + /project/test (dir) + test_app.py (file, loc=100) + /project/External (dir) + flask (package) + + Dependencies: + web/app.py --import--> db/models.py + web/app.py --import--> common/utils.py + web/app.py --import--> External/flask + web/views.py --import--> web/app.py + web/views.py --function_ref--> db/queries.py + test/test_app.py --import--> web/app.py + """ + model = SGraph(SElement(None, '')) + + # Top-level project + project = SElement(model.rootNode, 'project') + project.setType('repository') + + # src subtree + src = SElement(project, 'src') + src.setType('dir') + + web = SElement(src, 'web') + web.setType('dir') + + app = SElement(web, 'app.py') + app.setType('file') + app.attrs['loc'] = '500' + + views = SElement(web, 'views.py') + views.setType('file') + views.attrs['loc'] = '200' + + db = SElement(src, 'db') + db.setType('dir') + + models = SElement(db, 'models.py') + models.setType('file') + models.attrs['loc'] = '300' + + queries_elem = SElement(db, 'queries.py') + queries_elem.setType('file') + queries_elem.attrs['loc'] = '150' + + common = SElement(src, 'common') + common.setType('dir') + + utils = SElement(common, 'utils.py') + utils.setType('file') + utils.attrs['loc'] = '50' + + # test subtree + test_dir = SElement(project, 'test') + test_dir.setType('dir') + + test_app = SElement(test_dir, 'test_app.py') + test_app.setType('file') + test_app.attrs['loc'] = '100' + + # External subtree + external = SElement(project, 'External') + external.setType('dir') + + flask = SElement(external, 'flask') + flask.setType('package') + + # Dependencies + SElementAssociation(app, models, 'import').initElems() + SElementAssociation(app, utils, 'import').initElems() + SElementAssociation(app, 
flask, 'import').initElems() + SElementAssociation(views, app, 'import').initElems() + SElementAssociation(views, queries_elem, 'function_ref').initElems() + SElementAssociation(test_app, app, 'import').initElems() + + return model + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def get_all_paths(result_model: SGraph) -> set[str]: + """Collect every element path present in the result model.""" + paths: set[str] = set() + result_model.rootNode.traverseElements(lambda e: paths.add(e.getPath())) + return paths + + +def get_all_associations(result_model: SGraph) -> list[SElementAssociation]: + """Collect all associations reachable from the result model's root.""" + assocs: list[SElementAssociation] = [] + + def collect(e: SElement) -> None: + assocs.extend(e.outgoing) + + result_model.rootNode.traverseElements(collect) + # Deduplicate by identity + return list({id(a): a for a in assocs}.values()) + + +# --------------------------------------------------------------------------- +# Parser tests +# --------------------------------------------------------------------------- + +class TestParser: + def test_parse_keyword(self): + expr = parse('phone') + assert isinstance(expr, KeywordExpr) + assert expr.keyword == 'phone' + assert expr.exact is False + + def test_parse_exact_path(self): + expr = parse('"/project/src/web"') + assert isinstance(expr, KeywordExpr) + assert expr.keyword == '/project/src/web' + assert expr.exact is True + + def test_parse_has_attr(self): + expr = parse('@loc') + assert isinstance(expr, HasAttrExpr) + assert expr.attr_name == 'loc' + + def test_parse_attr_equals_unquoted(self): + expr = parse('@type=file') + assert isinstance(expr, AttrEqualsExpr) + assert expr.attr_name == 'type' + assert expr.value == 'file' + assert expr.exact is False + + def test_parse_attr_equals_quoted(self): + expr = parse('@type="file"') + 
assert isinstance(expr, AttrEqualsExpr) + assert expr.attr_name == 'type' + assert expr.value == 'file' + assert expr.exact is True + + def test_parse_attr_not_equals(self): + expr = parse('@type!=dir') + assert isinstance(expr, AttrNotEqualsExpr) + assert expr.attr_name == 'type' + assert expr.value == 'dir' + + def test_parse_attr_not_equals_quoted(self): + expr = parse('@type!="dir"') + assert isinstance(expr, AttrNotEqualsExpr) + assert expr.attr_name == 'type' + assert expr.value == 'dir' + assert expr.exact is True + + def test_parse_attr_gt(self): + expr = parse('@loc>100') + assert isinstance(expr, AttrGtExpr) + assert expr.attr_name == 'loc' + assert expr.value == 100.0 + + def test_parse_attr_lt(self): + expr = parse('@loc<200') + assert isinstance(expr, AttrLtExpr) + assert expr.attr_name == 'loc' + assert expr.value == 200.0 + + def test_parse_attr_gt_float(self): + expr = parse('@loc>99.5') + assert isinstance(expr, AttrGtExpr) + assert expr.value == 99.5 + + def test_parse_attr_regex(self): + expr = parse('@name=~".*\\.py$"') + assert isinstance(expr, AttrRegexExpr) + assert expr.attr_name == 'name' + assert expr.pattern == '.*\\.py$' + + def test_parse_and(self): + expr = parse('"/project/src" AND @type=file') + assert isinstance(expr, AndExpr) + assert isinstance(expr.left, KeywordExpr) + assert isinstance(expr.right, AttrEqualsExpr) + assert expr.left.keyword == '/project/src' + + def test_parse_or(self): + expr = parse('@type=file OR @type=dir') + assert isinstance(expr, OrExpr) + assert isinstance(expr.left, AttrEqualsExpr) + assert isinstance(expr.right, AttrEqualsExpr) + + def test_parse_not(self): + expr = parse('NOT @type=dir') + assert isinstance(expr, NotExpr) + assert isinstance(expr.inner, AttrEqualsExpr) + + def test_parse_parens(self): + expr = parse('(@type=file OR @type=dir) AND @loc>100') + assert isinstance(expr, AndExpr) + assert isinstance(expr.left, ParenExpr) + assert isinstance(expr.right, AttrGtExpr) + inner = expr.left.inner 
+ assert isinstance(inner, OrExpr) + + def test_parse_dep_directed(self): + expr = parse('"/project/src/web" --> "/project/src/db"') + assert isinstance(expr, DepSearchExpr) + assert expr.directed is True + assert isinstance(expr.from_expr, KeywordExpr) + assert isinstance(expr.to_expr, KeywordExpr) + assert expr.from_expr.keyword == '/project/src/web' + assert expr.to_expr.keyword == '/project/src/db' + + def test_parse_dep_undirected(self): + expr = parse('"/project/src/web" -- "/project/src/db"') + assert isinstance(expr, DepSearchExpr) + assert expr.directed is False + + def test_parse_dep_with_type(self): + expr = parse('"/web" -import-> "/db"') + assert isinstance(expr, DepSearchExpr) + assert expr.directed is True + assert expr.dep_type == 'import' + + def test_parse_wildcard_dep_from(self): + expr = parse('"*" --> "/project/src/db"') + assert isinstance(expr, DepSearchExpr) + # from_expr is None for wildcard + assert expr.from_expr is None + assert isinstance(expr.to_expr, KeywordExpr) + + def test_parse_wildcard_dep_to(self): + expr = parse('"/project/src/web" --> "*"') + assert isinstance(expr, DepSearchExpr) + assert isinstance(expr.from_expr, KeywordExpr) + assert expr.to_expr is None + + def test_parse_precedence_and_before_or(self): + # 'A OR B AND C' should group as 'A OR (B AND C)' since AND binds tighter + expr = parse('app OR views AND @loc>100') + # Top-level must be OR + assert isinstance(expr, OrExpr) + assert isinstance(expr.left, KeywordExpr) + # Right side must be AND (higher precedence) + assert isinstance(expr.right, AndExpr) + + def test_parse_chained_and(self): + expr = parse('@type=file AND @loc>100 AND "/project/src"') + # Both ANDs should be present (right-nested: the parser splits on the first top-level AND) + assert isinstance(expr, AndExpr) + + def test_parse_keyword_case_preserved(self): + expr = parse('MyModule') + assert isinstance(expr, KeywordExpr) + assert expr.keyword == 'MyModule' + + def test_parse_not_with_parens(self): + expr = parse('NOT (@type=file OR 
@type=dir)') + assert isinstance(expr, NotExpr) + assert isinstance(expr.inner, ParenExpr) + + +# --------------------------------------------------------------------------- +# Evaluator tests (integration with model) +# --------------------------------------------------------------------------- + +class TestEvaluator: + @pytest.fixture + def model(self) -> SGraph: + return create_test_model() + + # --- Keyword search --- + + def test_keyword_partial_match(self, model: SGraph): + result = query(model, 'app') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/test/test_app.py' in paths + + def test_keyword_no_match(self, model: SGraph): + result = query(model, 'nonexistent_zzz') + paths = get_all_paths(result) + # Only root node (empty graph skeleton) — no real project elements + assert '/project/src/web/app.py' not in paths + + def test_keyword_case_insensitive(self, model: SGraph): + result = query(model, 'FLASK') + paths = get_all_paths(result) + assert '/project/External/flask' in paths + + # --- Exact path match --- + + def test_exact_path_subtree(self, model: SGraph): + result = query(model, '"/project/src/web"') + paths = get_all_paths(result) + assert '/project/src/web' in paths + assert '/project/src/web/app.py' in paths + assert '/project/src/web/views.py' in paths + + def test_exact_path_excludes_siblings(self, model: SGraph): + result = query(model, '"/project/src/web"') + paths = get_all_paths(result) + assert '/project/src/db/models.py' not in paths + assert '/project/test/test_app.py' not in paths + + def test_exact_path_single_file(self, model: SGraph): + result = query(model, '"/project/src/db/models.py"') + paths = get_all_paths(result) + assert '/project/src/db/models.py' in paths + assert '/project/src/db/queries.py' not in paths + + # --- Attribute filtering --- + + def test_attr_type_file(self, model: SGraph): + result = query(model, '@type=file') + paths = get_all_paths(result) + assert 
'/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + assert '/project/src/db/queries.py' in paths + assert '/project/test/test_app.py' in paths + # Directories must be excluded + assert '/project/src/web' not in paths + assert '/project/src/db' not in paths + + def test_attr_type_package(self, model: SGraph): + result = query(model, '@type=package') + paths = get_all_paths(result) + assert '/project/External/flask' in paths + assert '/project/src/web/app.py' not in paths + + def test_attr_has_loc(self, model: SGraph): + result = query(model, '@loc') + paths = get_all_paths(result) + # All file elements have loc; dirs and packages do not + assert '/project/src/web/app.py' in paths + assert '/project/External/flask' not in paths + + def test_attr_gt(self, model: SGraph): + result = query(model, '@loc>200') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths # loc=500 + assert '/project/src/db/models.py' in paths # loc=300 + assert '/project/src/db/queries.py' not in paths # loc=150 + assert '/project/src/web/views.py' not in paths # loc=200, not >200 + + def test_attr_gt_boundary(self, model: SGraph): + # loc>199 should include views.py (200) and above + result = query(model, '@loc>199') + paths = get_all_paths(result) + assert '/project/src/web/views.py' in paths + + def test_attr_lt(self, model: SGraph): + result = query(model, '@loc<200') + paths = get_all_paths(result) + assert '/project/src/common/utils.py' in paths # loc=50 + assert '/project/test/test_app.py' in paths # loc=100 + assert '/project/src/db/queries.py' in paths # loc=150 + assert '/project/src/web/app.py' not in paths # loc=500 + + def test_attr_not_equals(self, model: SGraph): + result = query(model, '@type!=file') + paths = get_all_paths(result) + # Package and dir elements should appear; files should not + assert '/project/External/flask' in paths + assert '/project/src/web/app.py' not in paths + + # --- Boolean combinators --- + + def 
test_and_path_and_type(self, model: SGraph): + result = query(model, '"/project/src" AND @type=file') + paths = get_all_paths(result) + # Only files under /project/src + assert '/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + # test_app.py is NOT under /project/src + assert '/project/test/test_app.py' not in paths + + def test_and_path_and_loc(self, model: SGraph): + result = query(model, '"/project/src/db" AND @loc>200') + paths = get_all_paths(result) + assert '/project/src/db/models.py' in paths # loc=300 + assert '/project/src/db/queries.py' not in paths # loc=150 + + def test_or_two_subtrees(self, model: SGraph): + result = query(model, '"/project/src/web" OR "/project/test"') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/src/web/views.py' in paths + assert '/project/test/test_app.py' in paths + assert '/project/src/db/models.py' not in paths + + def test_or_two_types(self, model: SGraph): + result = query(model, '@type=file OR @type=package') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/External/flask' in paths + assert '/project/src/web' not in paths + + def test_not_excludes_subtree(self, model: SGraph): + result = query(model, '"/project/src" AND NOT "/project/src/web"') + paths = get_all_paths(result) + assert '/project/src/db/models.py' in paths + assert '/project/src/common/utils.py' in paths + assert '/project/src/web/app.py' not in paths + assert '/project/src/web/views.py' not in paths + + def test_not_type(self, model: SGraph): + result = query(model, '"/project/src" AND NOT @type=file') + paths = get_all_paths(result) + # Dirs under src should appear, but not files + assert '/project/src/web' in paths + assert '/project/src/db' in paths + assert '/project/src/web/app.py' not in paths + + def test_parens_or_then_and(self, model: SGraph): + result = query(model, '(@type=file OR @type=dir) AND @loc>200') + paths = 
get_all_paths(result) + assert '/project/src/web/app.py' in paths # file, loc=500 + assert '/project/src/db/models.py' in paths # file, loc=300 + assert '/project/src/web/views.py' not in paths # loc=200, not >200 + assert '/project/src/common/utils.py' not in paths # loc=50 + + # --- Dependency searches --- + + def test_dep_directed_match(self, model: SGraph): + result = query(model, '"/project/src/web/app.py" --> "/project/src/db/models.py"') + paths = get_all_paths(result) + assocs = get_all_associations(result) + assert '/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + assert len(assocs) >= 1 + + def test_dep_directed_wrong_direction(self, model: SGraph): + # db/models.py does NOT depend on web/app.py — wrong direction + result = query(model, '"/project/src/db/models.py" --> "/project/src/web/app.py"') + assocs = get_all_associations(result) + assert len(assocs) == 0 + + def test_dep_directed_no_such_dep(self, model: SGraph): + # utils.py does not depend on models.py + result = query(model, '"/project/src/common/utils.py" --> "/project/src/db/models.py"') + assocs = get_all_associations(result) + assert len(assocs) == 0 + + def test_dep_undirected_finds_both_directions(self, model: SGraph): + # views.py imports app.py (views → app), so undirected should find it + result = query(model, '"/project/src/web/app.py" -- "/project/src/web/views.py"') + assocs = get_all_associations(result) + assert len(assocs) >= 1 + + def test_dep_with_type_import(self, model: SGraph): + result = query(model, '"/project/src/web" -import-> "/project/src/db"') + assocs = get_all_associations(result) + deptypes = {a.deptype for a in assocs} + # app.py --import--> models.py is within these subtrees + assert 'import' in deptypes + # views.py --function_ref--> queries.py must not appear + assert 'function_ref' not in deptypes + + def test_dep_with_type_function_ref(self, model: SGraph): + result = query(model, '"/project/src/web" -function_ref-> 
"/project/src/db"') + assocs = get_all_associations(result) + deptypes = {a.deptype for a in assocs} + assert 'function_ref' in deptypes + assert 'import' not in deptypes + + def test_dep_wildcard_from(self, model: SGraph): + result = query(model, '"*" --> "/project/src/db/models.py"') + assocs = get_all_associations(result) + # app.py imports models.py + assert len(assocs) >= 1 + to_paths = {a.toElement.getPath() for a in assocs} + assert '/project/src/db/models.py' in to_paths + + def test_dep_wildcard_to(self, model: SGraph): + result = query(model, '"/project/src/web/app.py" --> "*"') + assocs = get_all_associations(result) + # app.py has 3 outgoing: models.py, utils.py, flask + assert len(assocs) >= 3 + + def test_dep_subtree_to_subtree(self, model: SGraph): + # All deps from /project/src/web to anywhere + result = query(model, '"/project/src/web" --> "*"') + assocs = get_all_associations(result) + from_paths = {a.fromElement.getPath() for a in assocs} + # At least app.py and views.py contribute outgoing deps + assert any('web' in p for p in from_paths) + + # --- Regression: result model contains both endpoints --- + + def test_dep_result_contains_both_endpoints(self, model: SGraph): + result = query(model, '"/project/src/web/views.py" --> "/project/src/web/app.py"') + paths = get_all_paths(result) + assert '/project/src/web/views.py' in paths + assert '/project/src/web/app.py' in paths + + # --- Edge cases --- + + def test_empty_result(self, model: SGraph): + result = query(model, '@type=nonexistent_type_xyz') + paths = get_all_paths(result) + # Should return an (almost) empty model — no project elements + assert '/project/src/web/app.py' not in paths + + def test_attr_regex_py_files(self, model: SGraph): + result = query(model, '@name=~".*\\.py$"') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + # flask package does not end in .py + assert '/project/External/flask' not in paths + + 
def test_complex_query(self, model: SGraph): + # Files under src with loc > 100 but NOT in the web subdir + result = query(model, '"/project/src" AND @type=file AND @loc>100 AND NOT "/project/src/web"') + paths = get_all_paths(result) + assert '/project/src/db/models.py' in paths # loc=300 + assert '/project/src/db/queries.py' in paths # loc=150 + assert '/project/src/web/app.py' not in paths # excluded by NOT + assert '/project/src/common/utils.py' not in paths # loc=50 + + +# --------------------------------------------------------------------------- +# P2: Chain search and shortest path +# --------------------------------------------------------------------------- + +class TestParserP2: + """Parser tests for ---> and --- operators.""" + + def test_parse_chain_search(self): + expr = parse('"/a" ---> "/b"') + assert isinstance(expr, ChainSearchExpr) + assert isinstance(expr.from_expr, KeywordExpr) + assert expr.from_expr.keyword == '/a' + assert isinstance(expr.to_expr, KeywordExpr) + assert expr.to_expr.keyword == '/b' + + def test_parse_chain_search_with_type(self): + expr = parse('"/a" --import-> "/b"') + assert isinstance(expr, ChainSearchExpr) + assert expr.dep_type == 'import' + + def test_parse_chain_search_wildcard(self): + expr = parse('"*" ---> "/b"') + assert isinstance(expr, ChainSearchExpr) + assert expr.from_expr is None + + def test_parse_shortest_path(self): + expr = parse('"/a" --- "/b"') + assert isinstance(expr, ShortestPathExpr) + assert isinstance(expr.from_expr, KeywordExpr) + assert expr.from_expr.keyword == '/a' + assert isinstance(expr.to_expr, KeywordExpr) + assert expr.to_expr.keyword == '/b' + + def test_parse_shortest_path_wildcard(self): + expr = parse('"*" --- "/b"') + assert isinstance(expr, ShortestPathExpr) + assert expr.from_expr is None + + def test_precedence_shortest_before_chain(self): + # --- should be tried before ---> + expr = parse('"/a" --- "/b"') + assert isinstance(expr, ShortestPathExpr) + + def 
test_precedence_chain_before_dep(self): + # ---> should be parsed as chain, not dep search + expr = parse('"/a" ---> "/b"') + assert isinstance(expr, ChainSearchExpr) + + def test_dep_search_still_works(self): + # --> should still be dep search + expr = parse('"/a" --> "/b"') + assert isinstance(expr, DepSearchExpr) + + def test_undirected_dep_still_works(self): + # -- should still be dep search (not shortest path) + expr = parse('"/a" -- "/b"') + assert isinstance(expr, DepSearchExpr) + + +class TestEvaluatorP2: + """Evaluator tests for chain search and shortest path. + + Uses the same test model but the dependency chain is: + views.py --import--> app.py --import--> models.py + views.py --import--> app.py --import--> utils.py + views.py --import--> app.py --import--> flask + test_app.py --import--> app.py --import--> models.py + """ + + @pytest.fixture + def model(self) -> SGraph: + return create_test_model() + + # --- Chain search (--->) --- + + def test_chain_search_2hop(self, model: SGraph): + # views.py -> app.py -> models.py (2-hop chain) + result = query(model, '"/project/src/web/views.py" ---> "/project/src/db/models.py"') + assocs = get_all_associations(result) + paths = get_all_paths(result) + assert len(assocs) >= 2 # views->app and app->models + assert '/project/src/web/views.py' in paths + assert '/project/src/web/app.py' in paths # intermediate + assert '/project/src/db/models.py' in paths + + def test_chain_search_direct(self, model: SGraph): + # app.py -> models.py is also a 1-hop chain + result = query(model, '"/project/src/web/app.py" ---> "/project/src/db/models.py"') + assocs = get_all_associations(result) + assert len(assocs) >= 1 + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + + def test_chain_search_no_path(self, model: SGraph): + # No chain from models.py to views.py (wrong direction) + result = query(model, '"/project/src/db/models.py" ---> 
"/project/src/web/views.py"') + assocs = get_all_associations(result) + assert len(assocs) == 0 + + def test_chain_search_with_type_filter(self, model: SGraph): + # views.py --import--> app.py --import--> models.py (only import edges) + result = query(model, '"/project/src/web/views.py" --import-> "/project/src/db/models.py"') + assocs = get_all_associations(result) + deptypes = {a.deptype for a in assocs} + assert deptypes == {'import'} + assert len(assocs) >= 2 + + def test_chain_search_type_filter_blocks(self, model: SGraph): + # views.py has function_ref to queries.py, but app.py does not + # So chain via function_ref edges only won't reach models.py + result = query(model, '"/project/src/web/views.py" --function_ref-> "/project/src/db/models.py"') + assocs = get_all_associations(result) + assert len(assocs) == 0 + + def test_chain_search_subtree(self, model: SGraph): + # All chains from /web subtree to /db subtree + result = query(model, '"/project/src/web" ---> "/project/src/db"') + assocs = get_all_associations(result) + # Should find chains: views->app->models, views->app->queries (via function_ref... 
wait) + # Actually app->models (import), views->app->models (2 hops) + assert len(assocs) >= 1 + + def test_chain_search_wildcard_to(self, model: SGraph): + # All chains from test_app.py to anywhere — should find transitive deps + result = query(model, '"/project/test/test_app.py" ---> "*"') + assocs = get_all_associations(result) + # test_app -> app -> models, test_app -> app -> utils, test_app -> app -> flask + assert len(assocs) >= 2 + paths = get_all_paths(result) + assert '/project/test/test_app.py' in paths + assert '/project/src/web/app.py' in paths + + # --- Shortest path (---) --- + + def test_shortest_path_direct_neighbors(self, model: SGraph): + # app.py and models.py are directly connected + result = query(model, '"/project/src/web/app.py" --- "/project/src/db/models.py"') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths + assert '/project/src/db/models.py' in paths + assocs = get_all_associations(result) + assert len(assocs) >= 1 + + def test_shortest_path_2hop(self, model: SGraph): + # views.py -> app.py -> models.py (shortest = 2 hops) + result = query(model, '"/project/src/web/views.py" --- "/project/src/db/models.py"') + paths = get_all_paths(result) + assert '/project/src/web/views.py' in paths + assert '/project/src/db/models.py' in paths + # Intermediate app.py should be on the path + assert '/project/src/web/app.py' in paths + + def test_shortest_path_undirected(self, model: SGraph): + # models.py -> app.py is reverse direction, but --- is undirected + result = query(model, '"/project/src/db/models.py" --- "/project/src/web/app.py"') + paths = get_all_paths(result) + assert '/project/src/db/models.py' in paths + assert '/project/src/web/app.py' in paths + assocs = get_all_associations(result) + assert len(assocs) >= 1 + + def test_shortest_path_no_connection(self, model: SGraph): + # utils.py and queries.py have no connection (utils has no deps at all) + result = query(model, '"/project/src/common/utils.py" --- 
"/project/src/db/queries.py"') + # They ARE connected: views.py -> queries.py AND views.py -> app.py -> utils.py + # But this is multi-hop. If no path exists, result is empty. + # Actually: app.py -> utils.py (outgoing from app.py) + # and views.py -> queries.py (function_ref) + # So: queries.py <- views.py -> app.py -> utils.py (undirected path of length 3) + # Let's just check it doesn't crash + paths = get_all_paths(result) + # There should be a path (possibly long) + if paths: + assert '/project/src/common/utils.py' in paths + assert '/project/src/db/queries.py' in paths + + def test_shortest_path_same_element(self, model: SGraph): + # Trivial: element to itself + result = query(model, '"/project/src/web/app.py" --- "/project/src/web/app.py"') + paths = get_all_paths(result) + assert '/project/src/web/app.py' in paths From 83410a0132921d51657d9195d35c05ddf677173c Mon Sep 17 00:00:00 2001 From: Ville Laitila Date: Mon, 6 Apr 2026 21:16:35 +0300 Subject: [PATCH 2/4] feat(query): add max_depth parameter to chain search The chain search operator (--->) had a hardcoded depth cap of 20 hops, which prevented finding lineage chains deeper than that. Add a max_depth parameter to query() and evaluate() that plumbs through to the chain search DFS. Default remains 20 to preserve existing behavior. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/sgraph/query/engine.py | 5 +++-- src/sgraph/query/evaluator.py | 36 +++++++++++++++++++++++------------ 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/sgraph/query/engine.py b/src/sgraph/query/engine.py index d072f67..24d3cef 100644 --- a/src/sgraph/query/engine.py +++ b/src/sgraph/query/engine.py @@ -6,7 +6,7 @@ from sgraph.query.parser import parse -def query(model: SGraph, expression: str) -> SGraph: +def query(model: SGraph, expression: str, max_depth: int = 20) -> SGraph: """Execute an SGraph Query Language expression against a model. 
Parses *expression* into an AST and evaluates it against *model*, @@ -22,6 +22,7 @@ def query(model: SGraph, expression: str) -> SGraph: '@type=file AND @loc>500' '"/src/web" --> "/src/db"' '"/" AND NOT "/External"' + max_depth: Maximum hop count for chain search (``--->``). Defaults to 20. Returns: A new :class:`~sgraph.SGraph` with matching elements and their @@ -47,4 +48,4 @@ def query(model: SGraph, expression: str) -> SGraph: result = query(model, '"/" AND NOT "/External"') """ ast = parse(expression) - return evaluate(ast, model, total_model=model) + return evaluate(ast, model, total_model=model, max_depth=max_depth) diff --git a/src/sgraph/query/evaluator.py b/src/sgraph/query/evaluator.py index 5a619e7..6f76333 100644 --- a/src/sgraph/query/evaluator.py +++ b/src/sgraph/query/evaluator.py @@ -54,13 +54,19 @@ # Public entry point # --------------------------------------------------------------------------- -def evaluate(expr: Expression, model: SGraph, total_model: Optional[SGraph] = None) -> SGraph: +def evaluate( + expr: Expression, + model: SGraph, + total_model: Optional[SGraph] = None, + max_depth: int = 20, +) -> SGraph: """Evaluate *expr* against *model* and return a new filtered SGraph. Args: expr: Parsed AST node. model: Model to filter (accumulates through AND chains). total_model: Original unfiltered model for NOT complement. Defaults to *model*. + max_depth: Maximum hop count for chain search (``--->``). Defaults to 20. Returns: A new :class:`~sgraph.SGraph` with only matching elements. 
@@ -73,20 +79,20 @@ def evaluate(expr: Expression, model: SGraph, total_model: Optional[SGraph] = No if isinstance(expr, KeywordExpr): return _eval_keyword(expr, model) if isinstance(expr, AndExpr): - left = evaluate(expr.left, model, total) - return evaluate(expr.right, left, total) + left = evaluate(expr.left, model, total, max_depth=max_depth) + return evaluate(expr.right, left, total, max_depth=max_depth) if isinstance(expr, OrExpr): - left = evaluate(expr.left, model, total) - right = evaluate(expr.right, model, total) + left = evaluate(expr.left, model, total, max_depth=max_depth) + right = evaluate(expr.right, model, total, max_depth=max_depth) return _union(left, right) if isinstance(expr, NotExpr): return _eval_not(expr, model, total) if isinstance(expr, ParenExpr): - return evaluate(expr.inner, model, total) + return evaluate(expr.inner, model, total, max_depth=max_depth) if isinstance(expr, DepSearchExpr): return _eval_dep_search(expr, model, total) if isinstance(expr, ChainSearchExpr): - return _eval_chain_search(expr, model, total) + return _eval_chain_search(expr, model, total, max_depth=max_depth) if isinstance(expr, ShortestPathExpr): return _eval_shortest_path(expr, model, total) @@ -529,11 +535,17 @@ def _descendants(elements: list[SElement]) -> set[SElement]: # Chain search ( ---> ) # --------------------------------------------------------------------------- -_CHAIN_MAX_DEPTH = 20 - +def _eval_chain_search( + expr: ChainSearchExpr, + model: SGraph, + total: SGraph, + max_depth: int = 20, +) -> SGraph: + """Find all directed multi-hop chains FROM → ... → TO via DFS. -def _eval_chain_search(expr: ChainSearchExpr, model: SGraph, total: SGraph) -> SGraph: - """Find all directed multi-hop chains FROM → ... → TO via DFS.""" + *max_depth* caps the number of hops the DFS will follow from each + starting element. Defaults to 20. 
+ """ from_originals = _resolve_endpoint(expr.from_expr, total) to_originals = _resolve_endpoint(expr.to_expr, total) is_wildcard_to = expr.to_expr is None @@ -565,7 +577,7 @@ def is_target(path: str) -> bool: return path in to_paths def dfs(elem: SElement, visited: set[str], chain: list[SElementAssociation], depth: int) -> None: - if depth > _CHAIN_MAX_DEPTH: + if depth > max_depth: return for a in elem.outgoing: if not edge_ok(a): From 5b68f98c7a8aea2904ee65b505c660424b83112e Mon Sep 17 00:00:00 2001 From: Ville Laitila Date: Tue, 7 Apr 2026 08:50:36 +0300 Subject: [PATCH 3/4] feat(query): return QueryResult with chains from query() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The query() function now returns a QueryResult dataclass instead of a raw SGraph. This bundles two pieces of information that callers previously had to compute themselves: result.subgraph – the filtered SGraph (same as before) result.chains – tuple of ordered SElement tuples discovered by chain searches (--->) For non-chain queries (filters, attribute matches), chains is the empty tuple — always present, never None. The chain elements are the ORIGINAL model instances (not subgraph copies), so callers can use ``elem is original`` identity checks and access full attrs without worrying about parallel object hierarchies. Internally _eval_chain_search already had the ordered chain in the DFS recursion (the ``chain`` parameter); it just didn't expose it. A new optional ``chain_collector`` parameter threads through evaluate() into the DFS, picking up the path tuple at every match. The chain_collector is intentionally NOT threaded through NotExpr branches: chains found inside a NOT are the EXCLUDED set, not the user-visible result, and surfacing them would mislead. BREAKING CHANGE: query() now returns QueryResult instead of SGraph. 
Callers that previously did ``query(...).rootNode`` must either - migrate to ``query(...).subgraph.rootNode``, or - continue calling sgraph.query.evaluator.evaluate() directly, which still returns SGraph for use as an internal helper. The 78 tests in tests/test_query.py adapted by updating only one shared helper function (``get_all_paths``) to read .subgraph if present. All 78 tests still pass. --- src/sgraph/query/__init__.py | 10 +++++-- src/sgraph/query/engine.py | 35 ++++++++++++++++-------- src/sgraph/query/evaluator.py | 47 +++++++++++++++++++++++++++----- src/sgraph/query/result.py | 50 +++++++++++++++++++++++++++++++++++ tests/test_query.py | 22 ++++++++++----- 5 files changed, 139 insertions(+), 25 deletions(-) create mode 100644 src/sgraph/query/result.py diff --git a/src/sgraph/query/__init__.py b/src/sgraph/query/__init__.py index 6e903bd..bfc8a8b 100644 --- a/src/sgraph/query/__init__.py +++ b/src/sgraph/query/__init__.py @@ -2,10 +2,16 @@ Public API:: - from sgraph.query import query + from sgraph.query import query, QueryResult result = query(model, '@type=file AND @loc>500') + elements = result.subgraph # filtered SGraph + # Chain searches also populate result.chains: + result = query(model, '"/leaf" ---> "/ancestor"', max_depth=30) + for chain in result.chains: + ... 
""" from sgraph.query.engine import query +from sgraph.query.result import QueryResult -__all__ = ['query'] +__all__ = ['query', 'QueryResult'] diff --git a/src/sgraph/query/engine.py b/src/sgraph/query/engine.py index 24d3cef..9ee2365 100644 --- a/src/sgraph/query/engine.py +++ b/src/sgraph/query/engine.py @@ -4,14 +4,16 @@ from sgraph import SGraph from sgraph.query.evaluator import evaluate from sgraph.query.parser import parse +from sgraph.query.result import QueryResult -def query(model: SGraph, expression: str, max_depth: int = 20) -> SGraph: +def query(model: SGraph, expression: str, max_depth: int = 20) -> QueryResult: """Execute an SGraph Query Language expression against a model. Parses *expression* into an AST and evaluates it against *model*, - returning a new SGraph containing only the matching elements and the - dependency edges connecting them. + returning a :class:`~sgraph.query.result.QueryResult` that bundles + the filtered sub-graph together with any ordered chains discovered + by chain-search expressions. The original *model* is never mutated. @@ -25,8 +27,12 @@ def query(model: SGraph, expression: str, max_depth: int = 20) -> SGraph: max_depth: Maximum hop count for chain search (``--->``). Defaults to 20. Returns: - A new :class:`~sgraph.SGraph` with matching elements and their - connecting associations. + A :class:`QueryResult` with two fields: + + - ``subgraph`` — the filtered :class:`~sgraph.SGraph` + - ``chains`` — tuple of ordered :class:`~sgraph.SElement` tuples + (one per discovered chain). Empty for queries that do not + contain a ``--->`` chain search. Raises: ValueError: If *expression* cannot be parsed. @@ -40,12 +46,19 @@ def query(model: SGraph, expression: str, max_depth: int = 20) -> SGraph: # All Python files with more than 500 lines result = query(model, '@type=file AND @loc>500') + for elem in result.subgraph.rootNode.children: + ... 
- # Dependencies from web module to db module - result = query(model, '"/src/web" --> "/src/db"') - - # Everything except external dependencies - result = query(model, '"/" AND NOT "/External"') + # Dependencies from web module to db module — chains is populated + result = query(model, '"/src/web" ---> "/src/db"', max_depth=10) + for chain in result.chains: + for elem in chain: + ... """ ast = parse(expression) - return evaluate(ast, model, total_model=model, max_depth=max_depth) + chains: list = [] + subgraph = evaluate( + ast, model, total_model=model, max_depth=max_depth, + chain_collector=chains, + ) + return QueryResult(subgraph=subgraph, chains=tuple(chains)) diff --git a/src/sgraph/query/evaluator.py b/src/sgraph/query/evaluator.py index 6f76333..6308ec4 100644 --- a/src/sgraph/query/evaluator.py +++ b/src/sgraph/query/evaluator.py @@ -59,6 +59,7 @@ def evaluate( model: SGraph, total_model: Optional[SGraph] = None, max_depth: int = 20, + chain_collector: Optional[list] = None, ) -> SGraph: """Evaluate *expr* against *model* and return a new filtered SGraph. @@ -67,6 +68,12 @@ def evaluate( model: Model to filter (accumulates through AND chains). total_model: Original unfiltered model for NOT complement. Defaults to *model*. max_depth: Maximum hop count for chain search (``--->``). Defaults to 20. + chain_collector: Optional list that receives one ordered tuple of + :class:`SElement` instances per discovered chain when the + expression contains a chain search (``--->``). Pass ``None`` + to skip path tracking. Chains discovered inside a NOT branch + are NOT collected (they would be the *excluded* set, not the + user-visible result). Returns: A new :class:`~sgraph.SGraph` with only matching elements. 
@@ -79,20 +86,29 @@ def evaluate( if isinstance(expr, KeywordExpr): return _eval_keyword(expr, model) if isinstance(expr, AndExpr): - left = evaluate(expr.left, model, total, max_depth=max_depth) - return evaluate(expr.right, left, total, max_depth=max_depth) + left = evaluate(expr.left, model, total, max_depth=max_depth, + chain_collector=chain_collector) + return evaluate(expr.right, left, total, max_depth=max_depth, + chain_collector=chain_collector) if isinstance(expr, OrExpr): - left = evaluate(expr.left, model, total, max_depth=max_depth) - right = evaluate(expr.right, model, total, max_depth=max_depth) + left = evaluate(expr.left, model, total, max_depth=max_depth, + chain_collector=chain_collector) + right = evaluate(expr.right, model, total, max_depth=max_depth, + chain_collector=chain_collector) return _union(left, right) if isinstance(expr, NotExpr): + # Chains inside a NOT are the EXCLUDED set; don't surface them. return _eval_not(expr, model, total) if isinstance(expr, ParenExpr): - return evaluate(expr.inner, model, total, max_depth=max_depth) + return evaluate(expr.inner, model, total, max_depth=max_depth, + chain_collector=chain_collector) if isinstance(expr, DepSearchExpr): return _eval_dep_search(expr, model, total) if isinstance(expr, ChainSearchExpr): - return _eval_chain_search(expr, model, total, max_depth=max_depth) + return _eval_chain_search( + expr, model, total, max_depth=max_depth, + chain_collector=chain_collector, + ) if isinstance(expr, ShortestPathExpr): return _eval_shortest_path(expr, model, total) @@ -540,11 +556,17 @@ def _eval_chain_search( model: SGraph, total: SGraph, max_depth: int = 20, + chain_collector: Optional[list] = None, ) -> SGraph: """Find all directed multi-hop chains FROM → ... → TO via DFS. *max_depth* caps the number of hops the DFS will follow from each starting element. Defaults to 20. + + *chain_collector*, when provided, receives one entry per discovered + chain. 
Each entry is a tuple of original-model :class:`SElement` + instances in start-to-end order. Callers that want only the + sub-graph (no path enumeration) pass ``None``. """ from_originals = _resolve_endpoint(expr.from_expr, total) to_originals = _resolve_endpoint(expr.to_expr, total) @@ -598,6 +620,19 @@ def dfs(elem: SElement, visited: set[str], chain: list[SElementAssociation], dep if aid not in found_assoc_ids: found_assoc_ids.add(aid) chain_assocs.append(a) + + # Record the ordered element tuple for this chain when + # the caller asked for it. The chain starts at the DFS + # root (``fe``) and ends at the matched ``target``. + if chain_collector is not None: + ordered = ( + (chain[0].fromElement,) if chain + else (a.fromElement,) + ) + ordered = ordered + tuple(c.toElement for c in chain) + ordered = ordered + (target,) + chain_collector.append(ordered) + # Continue DFS through this node too (more chains possible) new_visited = visited | {target_path} dfs(target, new_visited, chain + [a], depth + 1) diff --git a/src/sgraph/query/result.py b/src/sgraph/query/result.py new file mode 100644 index 0000000..670607e --- /dev/null +++ b/src/sgraph/query/result.py @@ -0,0 +1,50 @@ +"""QueryResult — wraps a query's sub-graph plus metadata. + +Returned from :func:`sgraph.query.query`. The :attr:`subgraph` field is +the same SGraph the query language has always produced (filtered +elements + relevant associations). The :attr:`chains` field is a tuple +of ordered element tuples — each chain is one DFS walk discovered by a +chain search expression. For non-chain queries (filters, attribute +matches, etc.) ``chains`` is the empty tuple. + +Example:: + + from sgraph.query import query + + result = query(model, '"/leaf" ---> "/ancestor"', max_depth=30) + + # Same as before — sub-graph access + for elem in result.subgraph.rootNode.children: + ... + + # New — ordered chains in a single field + for chain in result.chains: + # chain is tuple[SElement, ...] 
starting at "/leaf" + # and ending at "/ancestor" + for elem in chain: + ... +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sgraph.selement import SElement + from sgraph.sgraph import SGraph + + +@dataclass(frozen=True) +class QueryResult: + """Result of a query language evaluation. + + Attributes: + subgraph: The filtered sub-SGraph. Always present. + chains: Tuple of ordered element tuples discovered by chain + search expressions (``"a" ---> "b"``). Each chain is one + DFS walk in start-to-end order. For queries that do not + contain a chain search (filters, attribute matches), the + tuple is empty. + """ + subgraph: 'SGraph' + chains: tuple[tuple['SElement', ...], ...] = field(default_factory=tuple) diff --git a/tests/test_query.py b/tests/test_query.py index 52d23f8..246d674 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -130,21 +130,31 @@ def create_test_model() -> SGraph: # Helpers # --------------------------------------------------------------------------- -def get_all_paths(result_model: SGraph) -> set[str]: - """Collect every element path present in the result model.""" +def get_all_paths(result) -> set[str]: + """Collect every element path present in the query result. + + Accepts either a :class:`QueryResult` (the new public API) or a + raw :class:`SGraph` (used by some legacy tests that still call + :func:`evaluate` directly). 
+ """ + sub = result.subgraph if hasattr(result, 'subgraph') else result paths: set[str] = set() - result_model.rootNode.traverseElements(lambda e: paths.add(e.getPath())) + sub.rootNode.traverseElements(lambda e: paths.add(e.getPath())) return paths -def get_all_associations(result_model: SGraph) -> list[SElementAssociation]: - """Collect all associations reachable from the result model's root.""" +def get_all_associations(result) -> list[SElementAssociation]: + """Collect all associations reachable from the query result's root. + + Accepts either a :class:`QueryResult` or a raw :class:`SGraph`. + """ + sub = result.subgraph if hasattr(result, 'subgraph') else result assocs: list[SElementAssociation] = [] def collect(e: SElement) -> None: assocs.extend(e.outgoing) - result_model.rootNode.traverseElements(collect) + sub.rootNode.traverseElements(collect) # Deduplicate by identity return list({id(a): a for a in assocs}.values()) From 724044448365ed9d6fc73d4d987031d6e96e71d2 Mon Sep 17 00:00:00 2001 From: Ville Laitila Date: Tue, 7 Apr 2026 15:31:29 +0300 Subject: [PATCH 4/4] feat(repr): add SGraph.__repr__ with bounded element count MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Default Python repr (``) is unhelpful in logs and obscures object identity when the same model is bound to several variables/dict keys. The new repr shows the root path, element count, and id, e.g.: Element count uses an iterative bounded walk (_REPR_COUNT_LIMIT = 10000), so repr() stays cheap on multi-million node models — important because debuggers, loggers, and exception formatters call repr() at arbitrary moments. The body is wrapped in try/except so a malformed model can never crash a logger. Also commits the existing untracked tests/mini_model.xml fixture that the new multi-root tests depend on, and updates the stale README example. 
--- README.md | 2 +- src/sgraph/sgraph.py | 48 ++++++++++++++++++++++ tests/mini_model.xml | 9 +++++ tests/sgraph_test.py | 96 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 tests/mini_model.xml diff --git a/README.md b/README.md index 13eaf8f..dedbe84 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Creating a simple model: >>> from sgraph import SElementAssociation >>> x = SGraph(SElement(None, '')) >>> x - + >>> x.to_deps(fname=None) diff --git a/src/sgraph/sgraph.py b/src/sgraph/sgraph.py index bd80595..af5f293 100644 --- a/src/sgraph/sgraph.py +++ b/src/sgraph/sgraph.py @@ -29,6 +29,24 @@ from .sgraph_utils import ParsingIntentionallyAborted, add_ea, find_assocs_between +def _bounded_descendant_count(root: SElement, limit: int) -> tuple[int, bool]: + """Count strict descendants of *root*, stopping at *limit*. + + Used by SGraph.__repr__ to keep repr() bounded on huge models. + Returns ``(count, hit_limit)`` where ``hit_limit`` is True iff the walk + was truncated. + """ + count = 0 + stack: list[SElement] = list(root.children) + while stack: + node = stack.pop() + count += 1 + if count >= limit: + return count, True + stack.extend(node.children) + return count, False + + class SGraph: rootNode: SElement # modelAttrs: dict[str, str] | dict[str, dict[str, str]] @@ -44,6 +62,36 @@ def __init__(self, root_node: SElement | None = None): self.propagateActions = [] self.totalModel = None + # repr() can be triggered by debuggers, exception formatters and any + # logger.debug('%r', graph) call. A full O(N) tree walk on multi-million + # node models would stall those code paths, so we cap the count. 
+ _REPR_COUNT_LIMIT = 10000 + + def __repr__(self) -> str: + try: + children = self.rootNode.children + identity = hex(id(self)) + element_count, hit_limit = _bounded_descendant_count( + self.rootNode, self._REPR_COUNT_LIMIT) + count_str = f'{element_count}+' if hit_limit else str(element_count) + + if not children: + return f'' + if len(children) == 1: + root_name = children[0].name or '?' + return (f'') + names = ','.join((c.name or '?') for c in children[:3]) + if len(children) > 3: + names += ',...' + size_part = f' count={len(children)}' + else: + size_part = '' + return (f'') + except Exception as exc: # never let repr crash a logger + return f'' + def addPropagateAction(self, a: str, v: str): self.propagateActions.append((a, v)) diff --git a/tests/mini_model.xml b/tests/mini_model.xml new file mode 100644 index 0000000..b50d28e --- /dev/null +++ b/tests/mini_model.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/tests/sgraph_test.py b/tests/sgraph_test.py index b3aef37..51f330b 100644 --- a/tests/sgraph_test.py +++ b/tests/sgraph_test.py @@ -2,9 +2,11 @@ import os from typing import Any +from sgraph import SGraph from sgraph.loader import ModelLoader MODELFILE = 'modelfile.xml' +MINI_MODELFILE = 'mini_model.xml' # Helper for creating the model def get_model(file_name: str) -> Any: @@ -25,3 +27,97 @@ def test_deepcopy(): assert graph1.produce_deps_tuples() == graph2.produce_deps_tuples() assert graph1.calculate_model_stats() == graph2.calculate_model_stats() + +def test_repr_empty_model(): + """Empty SGraph should produce a useful repr, not the default .""" + graph = SGraph() + text = repr(graph) + assert text.startswith(' with element count.""" + graph = SGraph() + graph.createOrGetElementFromPath('/junit4/src/main/java/Foo.java') + graph.createOrGetElementFromPath('/junit4/src/main/java/Bar.java') + text = repr(graph) + assert text.startswith('')) + assert count == 6 + + +def test_repr_multi_root_lists_top_level_children(): + """Multiple top-level 
children should be visible (not collapsed to one).""" + graph = get_model(MINI_MODELFILE) + text = repr(graph) + assert text.startswith('