diff --git a/lib/cypher/query.ex b/lib/cypher/query.ex index 3b42fb7..a200ec8 100644 --- a/lib/cypher/query.ex +++ b/lib/cypher/query.ex @@ -305,22 +305,34 @@ defmodule AshNeo4j.Cypher.Query do atom(), atom() | nil, atom(), - boolean() + boolean(), + [{String.t(), any()}] ) :: t() - def aggregate_per_record(source_label, pk_field, ids, path_segments, kind, field, name, uniq? \\ false) + def aggregate_per_record( + source_label, + pk_field, + ids, + path_segments, + kind, + field, + name, + uniq? \\ false, + dest_conditions \\ [] + ) when is_atom(pk_field) and is_list(ids) and is_list(path_segments) and is_atom(kind) do path = build_agg_path(path_segments) expr = aggregate_expr(kind, field, name, uniq?) src = labels_string(source_label) + {dest_where, dest_params} = build_dest_conditions(dest_conditions) %__MODULE__{ - clauses: [ - %Match{pattern: "(s:#{src})"}, - %Where{conditions: ["s.#{pk_field} IN $agg_ids"]}, - %OptionalMatch{pattern: "(s)#{path}"}, - %Return{items: ["s.#{pk_field} AS source_id", expr]} - ], - params: %{"agg_ids" => ids} + clauses: + [ + %Match{pattern: "(s:#{src})"}, + %Where{conditions: ["s.#{pk_field} IN $agg_ids"]}, + %OptionalMatch{pattern: "(s)#{path}"} + ] ++ dest_where ++ [%Return{items: ["s.#{pk_field} AS source_id", expr]}], + params: Map.merge(%{"agg_ids" => ids}, dest_params) } end @@ -337,22 +349,34 @@ defmodule AshNeo4j.Cypher.Query do atom(), atom() | nil, atom(), - boolean() + boolean(), + [{String.t(), any()}] ) :: t() - def aggregate_total(source_label, pk_field, ids, path_segments, kind, field, name, uniq? \\ false) + def aggregate_total( + source_label, + pk_field, + ids, + path_segments, + kind, + field, + name, + uniq? \\ false, + dest_conditions \\ [] + ) when is_atom(pk_field) and is_list(ids) and is_list(path_segments) and is_atom(kind) do path = build_agg_path(path_segments) expr = aggregate_expr(kind, field, name, uniq?) 
src = labels_string(source_label) + {dest_where, dest_params} = build_dest_conditions(dest_conditions) %__MODULE__{ - clauses: [ - %Match{pattern: "(s:#{src})"}, - %Where{conditions: ["s.#{pk_field} IN $agg_ids"]}, - %OptionalMatch{pattern: "(s)#{path}"}, - %Return{items: [expr]} - ], - params: %{"agg_ids" => ids} + clauses: + [ + %Match{pattern: "(s:#{src})"}, + %Where{conditions: ["s.#{pk_field} IN $agg_ids"]}, + %OptionalMatch{pattern: "(s)#{path}"} + ] ++ dest_where ++ [%Return{items: [expr]}], + params: Map.merge(%{"agg_ids" => ids}, dest_params) } end @@ -590,6 +614,20 @@ defmodule AshNeo4j.Cypher.Query do "NOT (#{variable})#{rel}(:#{dest_label})" end + defp build_dest_conditions([]), do: {[], %{}} + + defp build_dest_conditions(dest_conditions) do + {cond_strings, params} = + dest_conditions + |> Enum.with_index() + |> Enum.reduce({[], %{}}, fn {{prop, val}, idx}, {parts, params} -> + key = "agg_filter_#{idx}" + {["d.#{prop} = $#{key}" | parts], Map.put(params, key, val)} + end) + + {[%Where{conditions: Enum.reverse(cond_strings)}], params} + end + defp build_conditions(variable, conditions) do conditions |> Enum.with_index() diff --git a/lib/data_layer.ex b/lib/data_layer.ex index 783c6da..25b13a7 100644 --- a/lib/data_layer.ex +++ b/lib/data_layer.ex @@ -1084,11 +1084,25 @@ defmodule AshNeo4j.DataLayer do is_struct(aggregate.field, Ash.Query.Calculation) -> run_expr_agg(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping) - # When a filter is present on a plain or embedded aggregate, load full - # destination records in Elixir so Ash.Filter.Runtime can evaluate it. - # Honouring the filter is a contract implied by can?({:aggregate, kind}). + # When a filter is present, try to push scalar == conditions into Cypher. + # Falls back to Elixir-side filtering for complex or embedded-field filters. 
aggregate_has_filter?(aggregate) -> - run_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping) + case {simple_agg_filter(aggregate, dest_mapping), embedded} do + {{:ok, dest_conditions}, nil} -> + run_simple_filtered_aggregate( + mapping, + neo4j_pk, + ids, + aggregate, + mode, + path_segments, + neo4j_field, + dest_conditions + ) + + _ -> + run_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping) + end embedded -> {field_type, field_constraints} = embedded @@ -1206,6 +1220,100 @@ defmodule AshNeo4j.DataLayer do end end + # Handles aggregates whose filter is a set of simple scalar == conditions that can be + # expressed as WHERE clauses in Cypher, avoiding full record loading in Elixir. + defp run_simple_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, neo4j_field, dest_conditions) do + query = + case mode do + :per_record -> + CypherQuery.aggregate_per_record( + mapping.label_pair, + neo4j_pk, + ids, + path_segments, + aggregate.kind, + neo4j_field, + aggregate.name, + aggregate.uniq?, + dest_conditions + ) + + :total -> + CypherQuery.aggregate_total( + mapping.label_pair, + neo4j_pk, + ids, + path_segments, + aggregate.kind, + neo4j_field, + aggregate.name, + aggregate.uniq?, + dest_conditions + ) + end + + case Cypher.run(query) do + {:ok, %Bolty.Response{results: rows}} -> + case mode do + :per_record -> + {:ok, + Map.new(rows, fn row -> + {Map.get(row, "source_id"), Map.get(row, to_string(aggregate.name))} + end)} + + :total -> + value = rows |> List.first(%{}) |> Map.get(to_string(aggregate.name), aggregate.default_value) + {:ok, value} + end + + {:error, e} -> + {:error, e} + end + end + + # Returns {:ok, [{prop_string, value}]} when the aggregate filter consists entirely of + # scalar == equality predicates on non-embedded destination attributes, enabling + # WHERE pushdown into Cypher. Returns :complex otherwise and falls back to Elixir-side filtering. 
# Decides whether an aggregate's filter can be pushed into the Cypher WHERE
# clause that follows the OPTIONAL MATCH on the destination node.
#
# Returns `{:ok, [{property_name, value}]}` (in the order the predicates appear
# in the simplified filter) only when EVERY predicate is `attribute == value`
# and:
#
#   * the attribute lives directly on the destination resource (the ref has an
#     empty relationship path) — a ref such as `author.name` cannot be
#     expressed as `d.name`;
#   * the attribute is not a calculation and is not an embedded field;
#   * the value is a concrete, non-nil term — `nil` and refs to other
#     attributes cannot be expressed as `d.prop = $param`.
#
# Anything else returns `:complex`, which makes the caller fall back to
# filtering in Elixir. Falling back is always safe, so every doubtful case
# resolves to `:complex`.
defp simple_agg_filter(aggregate, dest_mapping) do
  filter = aggregate_query_filter(aggregate)

  try do
    predicates =
      filter
      |> Ash.Filter.to_simple_filter(skip_invalid?: false)
      |> Map.get(:predicates, [])

    case predicates do
      [_ | _] ->
        predicates
        |> Enum.reduce_while({:ok, []}, fn predicate, {:ok, acc} ->
          case agg_filter_condition(predicate, dest_mapping) do
            # Prepend (O(1)) and reverse once below instead of `acc ++ [..]`.
            {:ok, condition} -> {:cont, {:ok, [condition | acc]}}
            :complex -> {:halt, :complex}
          end
        end)
        |> case do
          {:ok, conditions} -> {:ok, Enum.reverse(conditions)}
          :complex -> :complex
        end

      # Empty list (nothing to push down) or a non-list predicates value.
      _ ->
        :complex
    end
  rescue
    # Deliberate best-effort: any failure while simplifying or inspecting the
    # filter (e.g. a statement that cannot be represented as a simple filter)
    # means "not pushable", and the Elixir-side path handles it instead.
    _ -> :complex
  end
end

# Translates one simple-filter predicate into a `{property, value}` pair, or
# `:complex` when the predicate cannot be expressed as `d.property = $param`.
defp agg_filter_condition(
       %{operator: :==, left: %Ash.Query.Ref{} = ref, right: value},
       dest_mapping
     ) do
  cond do
    # The ref traverses a relationship, so the attribute is not on the
    # destination node bound to `d` in the generated Cypher.
    Map.get(ref, :relationship_path) not in [nil, []] ->
      :complex

    match?(%Ash.Query.Calculation{}, Map.get(ref, :attribute)) ->
      :complex

    # `d.prop = $p` with a null parameter is never true in Cypher, and a ref
    # on the right-hand side is a column comparison, not a bindable value.
    is_nil(value) or match?(%Ash.Query.Ref{}, value) ->
      :complex

    true ->
      attr_name = Ash.Query.Ref.name(ref)

      case embedded_field_type(dest_mapping.module, attr_name) do
        nil ->
          prop =
            dest_mapping.properties
            |> Keyword.get(attr_name, attr_name)
            |> to_string()

          {:ok, {prop, value}}

        _embedded ->
          :complex
      end
  end
end

# Non-equality operators, negations, and any left-hand side that is not a
# plain attribute ref are never pushed down.
defp agg_filter_condition(_predicate, _dest_mapping), do: :complex

# Extracts the aggregate's target field value from each record, respecting uniq?.
defp extract_aggregate_field_values(records, aggregate) do values = diff --git a/test/aggregate_test.exs b/test/aggregate_test.exs index 5958d00..a9706ce 100644 --- a/test/aggregate_test.exs +++ b/test/aggregate_test.exs @@ -266,6 +266,60 @@ defmodule AshNeo4j.AggregateTest do end end + describe "scalar filter pushdown (#253 — == filters pushed to Cypher WHERE)" do + test "count with integer == filter counts only matching records" do + author = create_author() + post1 = create_post(author, "post1") + post2 = create_post(author, "post2") + create_comment_with_score(post1, "a", 10) + create_comment_with_score(post1, "b", 5) + create_comment_with_score(post1, "c", 10) + create_comment_with_score(post2, "d", 5) + + [p1, p2] = Post |> Ash.read!() |> Ash.load!([:high_score_count]) |> Enum.sort_by(& &1.title) + assert p1.high_score_count == 2 + assert p2.high_score_count == 0 + end + + test "exists with integer == filter is false when no matching records" do + author = create_author() + post = create_post(author, "post") + create_comment_with_score(post, "a", 5) + + [loaded] = Post |> Ash.read!() |> Ash.load!([:has_high_score]) + assert loaded.has_high_score == false + end + + test "exists with integer == filter is true when a matching record exists" do + author = create_author() + post = create_post(author, "post") + create_comment_with_score(post, "a", 5) + create_comment_with_score(post, "b", 10) + + [loaded] = Post |> Ash.read!() |> Ash.load!([:has_high_score]) + assert loaded.has_high_score == true + end + + test "sum with integer == filter totals only matching records" do + author = create_author() + post = create_post(author, "post") + create_comment_with_score(post, "a", 10) + create_comment_with_score(post, "b", 5) + create_comment_with_score(post, "c", 10) + + [loaded] = Post |> Ash.read!() |> Ash.load!([:high_score_total]) + assert loaded.high_score_total == 20 + end + + test "count with integer == filter returns 0 for post with no comments" do + author = 
create_author() + create_post(author, "empty post") + + [loaded] = Post |> Ash.read!() |> Ash.load!([:high_score_count]) + assert loaded.high_score_count == 0 + end + end + describe "aggregates on embedded struct fields" do test "list aggregate returns deserialized typed structs" do author = create_author() diff --git a/test/support/resource/post.ex b/test/support/resource/post.ex index 349b78d..4efd53e 100644 --- a/test/support/resource/post.ex +++ b/test/support/resource/post.ex @@ -101,6 +101,19 @@ defmodule AshNeo4j.Test.Resource.Post do list :alpha_comment_titles, :comments, :title do filter expr(title == "alpha") end + + # Integer equality filters — used to verify #253 (scalar == pushed to Cypher WHERE). + count :high_score_count, :comments do + filter expr(score == 10) + end + + exists :has_high_score, :comments do + filter expr(score == 10) + end + + sum :high_score_total, :comments, :score do + filter expr(score == 10) + end end preparations do