Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/psqlgraph/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import not_, or_
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, aliased

from psqlgraph import ext

Expand Down Expand Up @@ -212,8 +212,122 @@ def path(self, *paths):
assert (
not self.entity().is_abstract_base()
), "Please narrow your search by specifying a node subclass"
# TODO: This is the original
for e in entities:
self = self.join(*getattr(self.entity(), e).attr)
Copy link
Copy Markdown
Contributor

@kulgan kulgan Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not obvious but its creating two join statements, one for the edge and the other for the destination node. Not clear if both joins will need an alias, but pretty sure the destination node needs one.
I played with this a little bit and was able to get the existing tests to pass

this = self
for e in entities:
  relation = getattr(self.entity(), e)
  alias = aliased(Node.get_subclass_named(relation.target_class.__dst_class__))
  this = self.join(relation.attr[0], relation.attr[1].of_type(alias))

I also had to make an update to self.entity() to cater for when the join_point is an AliasedClass not an entity.

# _joinpoint after first iteration: {'_joinpoint_entity': <Mapper at 0x1051643a0; Test>, 'prev': ((<Mapper at 0x105127670; Edge2>, <Mapper at 0x1051643a0; Test>, 'src'), {'_joinpoint_entity': <Mapper at 0x105127670; Edge2>, 'prev': ((<Mapper at 0x105116d60; Foo>, <Mapper at 0x105127670; Edge2>, '_Edge2_in'), {(<Mapper at 0x105116d60; Foo>, <Mapper at 0x105127670; Edge2>, '_Edge2_in'): {'_joinpoint_entity': <Mapper at 0x105127670; Edge2>, 'prev': (...), (...): {...}}}), (<Mapper at 0x105127670; Edge2>, <Mapper at 0x1051643a0; Test>, 'src'): {'_joinpoint_entity': <Mapper at 0x1051643a0; Test>, 'prev': ((<Mapper at 0x105127670; Edge2>, <Mapper at 0x1051643a0; Test>, 'src'), {'_joinpoint_entity': <Mapper at 0x105127670; Edge2>, 'prev': (...)})}})}
# breakpoint()

# TODO: always alias tables
# for e in entities:
# breakpoint()
# joined_tables = [mapper.class_ for mapper in self._join_entities]
# current_self_type = self.__class__.__name__
# curr_statement = str(self.statement)
# breakpoint()
# # self.entity() breaks when using aliased=true in join, something is going on with _joinpoint_zero.entity()
# curr_entity = self.entity()
# breakpoint()
# # aliased_entity = aliased(self.entity())
# # self = self.join(*getattr(aliased_entity, e).attr)
# self = self.join(*getattr(self.entity(), e).attr, aliased=True)
# # _joinpoint after first iteration: {'_joinpoint_entity': <AliasedClass at 0x10a010e20; Test>}
# breakpoint()

# TODO: attempt to alias using tuple input to join function
# # may change this to a list if it's not necessary to define alias name which requires keeping track of count
# alias_counts = {}
# for e in entities:
# breakpoint()
# curr_entity = self.entity()
# curr_association = getattr(self.entity(), e).attr
# breakpoint()
#
# if count := alias_counts.get(curr_association[1]):
# alias_counts[curr_association[1]] = count + 1
# # temp_alias = aliased(curr_association[1])
# # curr_association = (curr_association[0], aliased(curr_association[1])) # this breaks
# breakpoint()
# else:
# alias_counts[curr_association[1]] = 1
#
# breakpoint()
# # this join statement includes target as the first param and the param for the on clause as the second
# # self = self.join(*curr_association.attr)
# self = self.join(*curr_association)
# breakpoint()

# TODO: attempt to alias using target class in getattr(self.entity(), e)
# # may change this to a list if it's not necessary to define alias name which requires keeping track of count
# alias_counts = {}
# for e in entities:
# breakpoint()
# curr_entity = self.entity()
# curr_association = getattr(self.entity(), e)
# target_class = curr_association.target_class
# breakpoint()
#
# if count := alias_counts.get(target_class):
# alias_counts[target_class] = count + 1
# temp_alias = aliased(target_class)
# curr_association.target_class = temp_alias
# breakpoint()
# else:
# alias_counts[target_class] = 1
#
# check_val = curr_association.attr
# breakpoint()
# # this join statement includes target as the first param and the param for the on clause as the second
# self = self.join(*curr_association.attr)
# breakpoint()

# TODO: track the joined entities and create aliases when duplicates found
# for e in entities:
# joined_tables = [mapper.class_ for mapper in self._join_entities]
# breakpoint()
# curr_entity = self.entity()
# curr_association = getattr(self.entity(), e)
# target_class = curr_association.target_class
# owning_class = curr_association.owning_class
# breakpoint()
#
# if target_class in joined_tables:
# curr_association.target_class = aliased(target_class)
# breakpoint()
# if owning_class in joined_tables:
# curr_association.owning_class = aliased(owning_class)
# breakpoint()
#
# check_val = curr_association.attr
# breakpoint()
# # this join statement includes target as the first param and the param for the on clause as the second
# self = self.join(*curr_association.attr)
# joined_tables = [mapper.class_ for mapper in self._join_entities]
# breakpoint()

# TODO: attempt to track the entity and alias directly
# # may change this to a list if it's not necessary to define alias name which requires keeping track of count
# alias_counts = {}
# for e in entities:
# curr_entity = self.entity()
# breakpoint()
#
# if count := alias_counts.get(curr_entity):
# alias_counts[curr_entity] = count + 1
# # curr_entity = aliased(curr_entity)
# curr_entity = aliased(curr_entity, name="x") # see what naming it does
# breakpoint()
# else:
# alias_counts[curr_entity] = 1
#
# check_association = getattr(curr_entity, e)
# check_attribute_tuple = check_association.attr
#
# # this join statement includes target as the first param and the param for the on clause as the second
# self = self.join(*getattr(curr_entity, e).attr)
# breakpoint()

# breakpoint()
return self

def _get_link_details(self, entity, link_name):
Expand Down
9 changes: 9 additions & 0 deletions test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from test import PsqlgraphBaseTest, models

import pytest
from sqlalchemy.exc import ProgrammingError

from psqlgraph import PolyEdge, PolyNode

Expand Down Expand Up @@ -62,6 +63,14 @@ def test_path(self):
1,
)

def test_path_aliases(self):
"""Verify path with multiple joins to the same table does not produce an error."""
with self.g.session_scope():
try:
self.g.nodes(models.Foo).path("tests.foos.tests.foos").count()
except ProgrammingError as e:
assert False, f"Unexpected error occurred: {e}"

def test_subq_path_no_filter(self):
with self.g.session_scope():
self.assertEqual(
Expand Down