Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions lib/cypher/query.ex
Original file line number Diff line number Diff line change
Expand Up @@ -305,22 +305,34 @@ defmodule AshNeo4j.Cypher.Query do
atom(),
atom() | nil,
atom(),
boolean()
boolean(),
[{String.t(), any()}]
) :: t()
def aggregate_per_record(source_label, pk_field, ids, path_segments, kind, field, name, uniq? \\ false)
def aggregate_per_record(
source_label,
pk_field,
ids,
path_segments,
kind,
field,
name,
uniq? \\ false,
dest_conditions \\ []
)
when is_atom(pk_field) and is_list(ids) and is_list(path_segments) and is_atom(kind) do
path = build_agg_path(path_segments)
expr = aggregate_expr(kind, field, name, uniq?)
src = labels_string(source_label)
{dest_where, dest_params} = build_dest_conditions(dest_conditions)

%__MODULE__{
clauses: [
%Match{pattern: "(s:#{src})"},
%Where{conditions: ["s.#{pk_field} IN $agg_ids"]},
%OptionalMatch{pattern: "(s)#{path}"},
%Return{items: ["s.#{pk_field} AS source_id", expr]}
],
params: %{"agg_ids" => ids}
clauses:
[
%Match{pattern: "(s:#{src})"},
%Where{conditions: ["s.#{pk_field} IN $agg_ids"]},
%OptionalMatch{pattern: "(s)#{path}"}
] ++ dest_where ++ [%Return{items: ["s.#{pk_field} AS source_id", expr]}],
params: Map.merge(%{"agg_ids" => ids}, dest_params)
}
end

Expand All @@ -337,22 +349,34 @@ defmodule AshNeo4j.Cypher.Query do
atom(),
atom() | nil,
atom(),
boolean()
boolean(),
[{String.t(), any()}]
) :: t()
def aggregate_total(source_label, pk_field, ids, path_segments, kind, field, name, uniq? \\ false)
def aggregate_total(
source_label,
pk_field,
ids,
path_segments,
kind,
field,
name,
uniq? \\ false,
dest_conditions \\ []
)
when is_atom(pk_field) and is_list(ids) and is_list(path_segments) and is_atom(kind) do
path = build_agg_path(path_segments)
expr = aggregate_expr(kind, field, name, uniq?)
src = labels_string(source_label)
{dest_where, dest_params} = build_dest_conditions(dest_conditions)

%__MODULE__{
clauses: [
%Match{pattern: "(s:#{src})"},
%Where{conditions: ["s.#{pk_field} IN $agg_ids"]},
%OptionalMatch{pattern: "(s)#{path}"},
%Return{items: [expr]}
],
params: %{"agg_ids" => ids}
clauses:
[
%Match{pattern: "(s:#{src})"},
%Where{conditions: ["s.#{pk_field} IN $agg_ids"]},
%OptionalMatch{pattern: "(s)#{path}"}
] ++ dest_where ++ [%Return{items: [expr]}],
params: Map.merge(%{"agg_ids" => ids}, dest_params)
}
end

Expand Down Expand Up @@ -590,6 +614,20 @@ defmodule AshNeo4j.Cypher.Query do
"NOT (#{variable})#{rel}(:#{dest_label})"
end

defp build_dest_conditions([]), do: {[], %{}}

defp build_dest_conditions(dest_conditions) do
{cond_strings, params} =
dest_conditions
|> Enum.with_index()
|> Enum.reduce({[], %{}}, fn {{prop, val}, idx}, {parts, params} ->
key = "agg_filter_#{idx}"
{["d.#{prop} = $#{key}" | parts], Map.put(params, key, val)}
end)

{[%Where{conditions: Enum.reverse(cond_strings)}], params}
end

defp build_conditions(variable, conditions) do
conditions
|> Enum.with_index()
Expand Down
116 changes: 112 additions & 4 deletions lib/data_layer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1084,11 +1084,25 @@ defmodule AshNeo4j.DataLayer do
is_struct(aggregate.field, Ash.Query.Calculation) ->
run_expr_agg(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping)

# When a filter is present on a plain or embedded aggregate, load full
# destination records in Elixir so Ash.Filter.Runtime can evaluate it.
# Honouring the filter is a contract implied by can?({:aggregate, kind}).
# When a filter is present, try to push scalar == conditions into Cypher.
# Falls back to Elixir-side filtering for complex or embedded-field filters.
aggregate_has_filter?(aggregate) ->
run_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping)
case {simple_agg_filter(aggregate, dest_mapping), embedded} do
{{:ok, dest_conditions}, nil} ->
run_simple_filtered_aggregate(
mapping,
neo4j_pk,
ids,
aggregate,
mode,
path_segments,
neo4j_field,
dest_conditions
)

_ ->
run_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, dest_mapping)
end

embedded ->
{field_type, field_constraints} = embedded
Expand Down Expand Up @@ -1206,6 +1220,100 @@ defmodule AshNeo4j.DataLayer do
end
end

# Handles aggregates whose filter is a set of simple scalar == conditions that can be
# expressed as WHERE clauses in Cypher, avoiding full record loading in Elixir.
defp run_simple_filtered_aggregate(mapping, neo4j_pk, ids, aggregate, mode, path_segments, neo4j_field, dest_conditions) do
query =
case mode do
:per_record ->
CypherQuery.aggregate_per_record(
mapping.label_pair,
neo4j_pk,
ids,
path_segments,
aggregate.kind,
neo4j_field,
aggregate.name,
aggregate.uniq?,
dest_conditions
)

:total ->
CypherQuery.aggregate_total(
mapping.label_pair,
neo4j_pk,
ids,
path_segments,
aggregate.kind,
neo4j_field,
aggregate.name,
aggregate.uniq?,
dest_conditions
)
end

case Cypher.run(query) do
{:ok, %Bolty.Response{results: rows}} ->
case mode do
:per_record ->
{:ok,
Map.new(rows, fn row ->
{Map.get(row, "source_id"), Map.get(row, to_string(aggregate.name))}
end)}

:total ->
value = rows |> List.first(%{}) |> Map.get(to_string(aggregate.name), aggregate.default_value)
{:ok, value}
end

{:error, e} ->
{:error, e}
end
end

# Returns {:ok, [{prop_string, value}]} when the aggregate filter consists entirely of
# scalar == equality predicates on non-embedded destination attributes, enabling
# WHERE pushdown into Cypher. Returns :complex otherwise and falls back to Elixir-side filtering.
defp simple_agg_filter(aggregate, dest_mapping) do
filter = aggregate_query_filter(aggregate)

try do
simple = Ash.Filter.to_simple_filter(filter, skip_invalid?: false)
predicates = Map.get(simple, :predicates, [])

if Enum.empty?(predicates) do
:complex
else
Enum.reduce_while(predicates, {:ok, []}, fn predicate, {:ok, acc} ->
cond do
Map.get(predicate, :operator) != :== ->
{:halt, :complex}

not match?(%Ash.Query.Ref{}, Map.get(predicate, :left)) ->
{:halt, :complex}

match?(%Ash.Query.Calculation{}, Map.get(predicate.left, :attribute)) ->
{:halt, :complex}

true ->
attr_name = Ash.Query.Ref.name(predicate.left)

case embedded_field_type(dest_mapping.module, attr_name) do
nil ->
prop = Keyword.get(dest_mapping.properties, attr_name, attr_name) |> to_string()
{:cont, {:ok, acc ++ [{prop, predicate.right}]}}

_ ->
{:halt, :complex}
end
end
end)
end
rescue
_ -> :complex
end
end

# Extracts the aggregate's target field value from each record, respecting uniq?.
defp extract_aggregate_field_values(records, aggregate) do
values =
Expand Down
54 changes: 54 additions & 0 deletions test/aggregate_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,60 @@ defmodule AshNeo4j.AggregateTest do
end
end

describe "scalar filter pushdown (#253 — == filters pushed to Cypher WHERE)" do
test "count with integer == filter counts only matching records" do
author = create_author()
post1 = create_post(author, "post1")
post2 = create_post(author, "post2")
create_comment_with_score(post1, "a", 10)
create_comment_with_score(post1, "b", 5)
create_comment_with_score(post1, "c", 10)
create_comment_with_score(post2, "d", 5)

[p1, p2] = Post |> Ash.read!() |> Ash.load!([:high_score_count]) |> Enum.sort_by(& &1.title)
assert p1.high_score_count == 2
assert p2.high_score_count == 0
end

test "exists with integer == filter is false when no matching records" do
author = create_author()
post = create_post(author, "post")
create_comment_with_score(post, "a", 5)

[loaded] = Post |> Ash.read!() |> Ash.load!([:has_high_score])
assert loaded.has_high_score == false
end

test "exists with integer == filter is true when a matching record exists" do
author = create_author()
post = create_post(author, "post")
create_comment_with_score(post, "a", 5)
create_comment_with_score(post, "b", 10)

[loaded] = Post |> Ash.read!() |> Ash.load!([:has_high_score])
assert loaded.has_high_score == true
end

test "sum with integer == filter totals only matching records" do
author = create_author()
post = create_post(author, "post")
create_comment_with_score(post, "a", 10)
create_comment_with_score(post, "b", 5)
create_comment_with_score(post, "c", 10)

[loaded] = Post |> Ash.read!() |> Ash.load!([:high_score_total])
assert loaded.high_score_total == 20
end

test "count with integer == filter returns 0 for post with no comments" do
author = create_author()
create_post(author, "empty post")

[loaded] = Post |> Ash.read!() |> Ash.load!([:high_score_count])
assert loaded.high_score_count == 0
end
end

describe "aggregates on embedded struct fields" do
test "list aggregate returns deserialized typed structs" do
author = create_author()
Expand Down
13 changes: 13 additions & 0 deletions test/support/resource/post.ex
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ defmodule AshNeo4j.Test.Resource.Post do
list :alpha_comment_titles, :comments, :title do
filter expr(title == "alpha")
end

# Integer equality filters — used to verify #253 (scalar == pushed to Cypher WHERE).
count :high_score_count, :comments do
filter expr(score == 10)
end

exists :has_high_score, :comments do
filter expr(score == 10)
end

sum :high_score_total, :comments, :score do
filter expr(score == 10)
end
end

preparations do
Expand Down
Loading