Skip to content

Commit 5448bfa

Browse files
author
Peng Ren
committed
Support nested query
1 parent 1ed47a2 commit 5448bfa

File tree

7 files changed

+254
-7
lines changed

7 files changed

+254
-7
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,23 @@ while users:
137137
### SELECT Statements
138138
- Field selection: `SELECT name, age FROM users`
139139
- Wildcards: `SELECT * FROM products`
140+
- **Nested fields**: `SELECT profile.name, profile.age FROM users`
141+
- **Array access**: `SELECT items[0], items[1].name FROM orders`
140142

141143
### WHERE Clauses
142144
- Equality: `WHERE name = 'John'`
143145
- Comparisons: `WHERE age > 25`, `WHERE price <= 100.0`
144146
- Logical operators: `WHERE age > 18 AND status = 'active'`
147+
- **Nested field filtering**: `WHERE profile.status = 'active'`
148+
- **Array filtering**: `WHERE items[0].price > 100`
149+
150+
### Nested Field Support
151+
- **Single-level**: `profile.name`, `settings.theme`
152+
- **Multi-level**: `account.profile.name`, `config.database.host`
153+
- **Array access**: `items[0].name`, `orders[1].total`
154+
- **Complex queries**: `WHERE customer.profile.age > 18 AND orders[0].status = 'paid'`
155+
156+
> **Note**: Avoid SQL reserved words (`user`, `data`, `value`, `count`, etc.) as unquoted field names. Use alternatives or bracket notation for arrays.
145157
146158
### Sorting and Limiting
147159
- ORDER BY: `ORDER BY name ASC, age DESC`

pymongosql/result_set.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from typing import Any, Dict, List, Optional, Sequence, Tuple
44

5+
import jmespath
56
from pymongo.cursor import Cursor as MongoCursor
67
from pymongo.errors import PyMongoError
78

@@ -125,14 +126,32 @@ def _process_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
125126
processed = {}
126127
for field_name, include_flag in self._execution_plan.projection_stage.items():
127128
if include_flag == 1: # Field is included in projection
128-
if field_name in doc:
129-
processed[field_name] = doc[field_name]
130-
elif field_name != "_id": # _id might be excluded by MongoDB
131-
# Field not found, set to None
132-
processed[field_name] = None
129+
# Use jmespath to handle nested field access (dot notation and array indexing)
130+
value = self._get_nested_value(doc, field_name)
131+
processed[field_name] = value
133132

134133
return processed
135134

135+
def _get_nested_value(self, doc: Dict[str, Any], field_path: str) -> Any:
136+
"""Extract nested field value from document using JMESPath
137+
138+
Supports:
139+
- Simple fields: "name" -> doc["name"]
140+
- Nested fields: "profile.bio" -> doc["profile"]["bio"]
141+
- Array indexing: "address.coordinates[1]" -> doc["address"]["coordinates"][1]
142+
- Wildcards: "items[*].name" -> [item["name"] for item in items]
143+
"""
144+
try:
145+
# Optimization: for simple field names without dots/brackets, use direct access
146+
if "." not in field_path and "[" not in field_path:
147+
return doc.get(field_path)
148+
149+
# Use jmespath for complex paths
150+
return jmespath.search(field_path, doc)
151+
except Exception as e:
152+
_logger.debug(f"Error extracting field '{field_path}': {e}")
153+
return None
154+
136155
def _dict_to_sequence(self, doc: Dict[str, Any]) -> Tuple[Any, ...]:
137156
"""Convert document dictionary to sequence according to column order"""
138157
if self._column_names is None:

pymongosql/sql/handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "P
893893
return projection
894894

895895
def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]:
896-
"""Extract field name and alias from projection item context"""
896+
"""Extract field name and alias from projection item context with nested field support"""
897897
if not hasattr(item, "children") or not item.children:
898898
return str(item), None
899899

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
antlr4-python3-runtime>=4.13.0
22
pymongo>=4.15.0
3+
jmespath>=1.0.0

tests/test_cursor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,28 @@ def test_execute_complex_query(self, conn):
138138
for row in rows:
139139
assert len(row) >= 2 # Should have at least name and email
140140

141+
def test_execute_nested_fields_query(self, conn):
142+
"""Test executing query with nested field access"""
143+
sql = "SELECT profile.bio, address.city, address.coordinates FROM users WHERE salary >= 100000"
144+
145+
cursor = conn.cursor()
146+
result = cursor.execute(sql)
147+
assert result == cursor
148+
assert isinstance(cursor.result_set, ResultSet)
149+
150+
# Get results - test nested field functionality
151+
rows = cursor.result_set.fetchall()
152+
assert isinstance(rows, list)
153+
assert len(rows) == 4
154+
155+
# Verify that nested fields are properly projected
156+
if cursor.result_set.description:
157+
col_names = [desc[0] for desc in cursor.result_set.description]
158+
# Should include nested field names in projection
159+
assert "profile.bio" in col_names
160+
assert "address.city" in col_names
161+
assert "address.coordinates" in col_names
162+
141163
def test_execute_parser_error(self, conn):
142164
"""Test executing query with parser errors"""
143165
sql = "INVALID SQL SYNTAX"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pymongosql.sql.parser import SQLParser
66

77

8-
class TestSQLParser:
8+
class TestSQLParserGeneral:
99
"""Comprehensive test suite for SQL parser from simple to complex queries"""
1010

1111
def test_simple_select_all(self):
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Comprehensive tests for nested field support in PyMongoSQL
4+
"""
5+
import pytest
6+
7+
from pymongosql.error import SqlSyntaxError
8+
from pymongosql.sql.parser import SQLParser
9+
10+
11+
class TestSQLParserNestedFields:
12+
"""Test suite for nested field querying functionality"""
13+
14+
def test_basic_single_level_nesting_select(self):
15+
"""Test basic single-level nested fields in SELECT"""
16+
sql = "SELECT c.a, c.b FROM collection"
17+
parser = SQLParser(sql)
18+
19+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
20+
21+
execution_plan = parser.get_execution_plan()
22+
assert execution_plan.collection == "collection"
23+
assert execution_plan.projection_stage == {"c.a": 1, "c.b": 1}
24+
assert execution_plan.filter_stage == {}
25+
26+
def test_basic_single_level_nesting_where(self):
27+
"""Test basic single-level nested fields in WHERE clause"""
28+
sql = "SELECT * FROM users WHERE profile.status = 'active'"
29+
parser = SQLParser(sql)
30+
31+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
32+
33+
execution_plan = parser.get_execution_plan()
34+
assert execution_plan.collection == "users"
35+
assert execution_plan.filter_stage == {"profile.status": "active"}
36+
37+
def test_multi_level_nesting_non_reserved_words(self):
38+
"""Test multi-level nested fields with non-reserved words"""
39+
sql = "SELECT account.profile.name FROM users WHERE account.settings.theme = 'dark'"
40+
parser = SQLParser(sql)
41+
42+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
43+
44+
execution_plan = parser.get_execution_plan()
45+
assert execution_plan.collection == "users"
46+
assert execution_plan.projection_stage == {"account.profile.name": 1}
47+
assert execution_plan.filter_stage == {"account.settings.theme": "dark"}
48+
49+
def test_array_bracket_notation_select(self):
50+
"""Test array access using bracket notation in SELECT"""
51+
sql = "SELECT items[0], items[1].name FROM orders"
52+
parser = SQLParser(sql)
53+
54+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
55+
56+
execution_plan = parser.get_execution_plan()
57+
assert execution_plan.collection == "orders"
58+
assert execution_plan.projection_stage == {"items[0]": 1, "items[1].name": 1}
59+
60+
def test_array_bracket_notation_where(self):
61+
"""Test array access using bracket notation in WHERE"""
62+
sql = "SELECT * FROM orders WHERE items[0].price > 100"
63+
parser = SQLParser(sql)
64+
65+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
66+
67+
execution_plan = parser.get_execution_plan()
68+
assert execution_plan.collection == "orders"
69+
assert execution_plan.filter_stage == {"items[0].price": {"$gt": 100}}
70+
71+
def test_quoted_reserved_words(self):
72+
"""Test using quoted reserved words as field names - currently limited support"""
73+
# Note: This test documents current limitations with quoted identifiers in complex paths
74+
sql = 'SELECT "user" FROM collection' # Simplified test that works
75+
parser = SQLParser(sql)
76+
77+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
78+
79+
execution_plan = parser.get_execution_plan()
80+
assert execution_plan.collection == "collection"
81+
assert execution_plan.projection_stage == {'"user"': 1}
82+
83+
def test_complex_nested_query(self):
84+
"""Test complex query with multiple nested field types"""
85+
sql = """
86+
SELECT
87+
customer.profile.name,
88+
orders[0].total,
89+
settings.preferences.theme
90+
FROM transactions
91+
WHERE customer.profile.age > 18
92+
AND orders[0].status = 'completed'
93+
AND settings.notifications = true
94+
"""
95+
parser = SQLParser(sql)
96+
97+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
98+
99+
execution_plan = parser.get_execution_plan()
100+
assert execution_plan.collection == "transactions"
101+
102+
expected_projection = {"customer.profile.name": 1, "orders[0].total": 1, "settings.preferences.theme": 1}
103+
assert execution_plan.projection_stage == expected_projection
104+
105+
# The filter should be a combination of conditions
106+
expected_filter = {
107+
"$and": [
108+
{"customer.profile.age": {"$gt": 18}},
109+
{"orders[0].status": "completed"},
110+
{"settings.notifications": True},
111+
]
112+
}
113+
assert execution_plan.filter_stage == expected_filter
114+
115+
def test_reserved_word_user_fails(self):
116+
"""Test that unquoted 'user' keyword fails"""
117+
sql = "SELECT user.profile.name FROM users"
118+
119+
with pytest.raises(SqlSyntaxError) as exc_info:
120+
parser = SQLParser(sql)
121+
parser.get_execution_plan()
122+
123+
assert "no viable alternative" in str(exc_info.value)
124+
125+
def test_reserved_word_value_fails(self):
126+
"""Test that unquoted 'value' keyword fails"""
127+
sql = "SELECT data.value FROM items"
128+
129+
with pytest.raises(SqlSyntaxError) as exc_info:
130+
parser = SQLParser(sql)
131+
parser.get_execution_plan()
132+
133+
assert "no viable alternative" in str(exc_info.value)
134+
135+
def test_numeric_dot_notation_fails(self):
136+
"""Test that numeric dot notation fails"""
137+
sql = "SELECT c.0.name FROM collection"
138+
139+
with pytest.raises(SqlSyntaxError) as exc_info:
140+
parser = SQLParser(sql)
141+
parser.get_execution_plan()
142+
143+
assert "mismatched input" in str(exc_info.value)
144+
145+
def test_nested_with_comparison_operators(self):
146+
"""Test nested fields with various comparison operators"""
147+
# Test supported comparison operators with non-reserved field names
148+
test_cases = [
149+
("profile.age > 18", {"profile.age": {"$gt": 18}}),
150+
("settings.total < 100", {"settings.total": {"$lt": 100}}), # Changed from 'count' (reserved)
151+
("status.active = true", {"status.active": True}),
152+
("config.name != 'default'", {"config.name": {"$ne": "default"}}),
153+
]
154+
155+
for where_clause, expected_filter in test_cases:
156+
sql = f"SELECT * FROM collection WHERE {where_clause}"
157+
parser = SQLParser(sql)
158+
159+
assert not parser.has_errors, f"Parser errors for '{where_clause}': {parser.errors}"
160+
161+
execution_plan = parser.get_execution_plan()
162+
assert execution_plan.filter_stage == expected_filter
163+
164+
def test_nested_with_logical_operators(self):
165+
"""Test nested fields with logical operators"""
166+
sql = """
167+
SELECT * FROM users
168+
WHERE profile.age > 18
169+
AND settings.active = true
170+
OR profile.vip = true
171+
"""
172+
parser = SQLParser(sql)
173+
174+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
175+
176+
execution_plan = parser.get_execution_plan()
177+
# The exact structure depends on operator precedence handling
178+
assert "profile.age" in str(execution_plan.filter_stage)
179+
assert "settings.active" in str(execution_plan.filter_stage)
180+
assert "profile.vip" in str(execution_plan.filter_stage)
181+
182+
def test_nested_with_aliases(self):
183+
"""Test nested fields with column aliases"""
184+
sql = "SELECT profile.name AS fullname, settings.theme AS ui_theme FROM users"
185+
parser = SQLParser(sql)
186+
187+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
188+
189+
execution_plan = parser.get_execution_plan()
190+
assert execution_plan.collection == "users"
191+
# Note: Current implementation uses original field names in projection
192+
# Aliases are handled at the result processing level
193+
assert execution_plan.projection_stage == {"profile.name": 1, "settings.theme": 1}

0 commit comments

Comments
 (0)