Skip to content

Commit 2e10142

Browse files
author
Peng Ren
committed
Clean up the code
1 parent 1332b0b commit 2e10142

File tree

4 files changed

+59
-178
lines changed

4 files changed

+59
-178
lines changed

pymongosql/cursor.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, TypeVar
44

5-
from pymongo.cursor import Cursor as MongoCursor
65
from pymongo.errors import PyMongoError
76

87
from .common import BaseCursor, CursorIterator
@@ -32,16 +31,15 @@ def __init__(self, connection: "Connection", **kwargs) -> None:
3231
self._result_set: Optional[ResultSet] = None
3332
self._result_set_class = ResultSet
3433
self._current_execution_plan: Optional[ExecutionPlan] = None
35-
self._mongo_cursor: Optional[MongoCursor] = None
3634
self._is_closed = False
3735

3836
@property
3937
def result_set(self) -> Optional[ResultSet]:
4038
return self._result_set
4139

4240
@result_set.setter
43-
def result_set(self, val: ResultSet) -> None:
44-
self._result_set = val
41+
def result_set(self, rs: ResultSet) -> None:
42+
self._result_set = rs
4543

4644
@property
4745
def has_result_set(self) -> bool:
@@ -52,8 +50,8 @@ def result_set_class(self) -> Optional[type]:
5250
return self._result_set_class
5351

5452
@result_set_class.setter
55-
def result_set_class(self, val: type) -> None:
56-
self._result_set_class = val
53+
def result_set_class(self, rs_cls: type) -> None:
54+
self._result_set_class = rs_cls
5755

5856
@property
5957
def rowcount(self) -> int:
@@ -107,7 +105,7 @@ def _execute_execution_plan(self, execution_plan: ExecutionPlan) -> None:
107105
# Build MongoDB find command
108106
find_command = {"find": execution_plan.collection, "filter": execution_plan.filter_stage or {}}
109107

110-
# Apply projection if specified (already in MongoDB format)
108+
# Apply projection if specified
111109
if execution_plan.projection_stage:
112110
find_command["projection"] = execution_plan.projection_stage
113111

@@ -236,15 +234,6 @@ def fetchall(self) -> List[Sequence[Any]]:
236234
def close(self) -> None:
237235
"""Close the cursor and free resources"""
238236
try:
239-
if self._mongo_cursor:
240-
# Close MongoDB cursor
241-
try:
242-
self._mongo_cursor.close()
243-
except Exception as e:
244-
_logger.warning(f"Error closing MongoDB cursor: {e}")
245-
finally:
246-
self._mongo_cursor = None
247-
248237
if self._result_set:
249238
# Close result set
250239
try:

pymongosql/sql/ast.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,20 @@ def visitLimitClause(self, ctx: PartiQLParser.LimitClauseContext) -> Any:
154154
except Exception as e:
155155
_logger.warning(f"Error processing LIMIT clause: {e}")
156156
return self.visitChildren(ctx)
157+
158+
def visitOffsetByClause(self, ctx: PartiQLParser.OffsetByClauseContext) -> Any:
159+
"""Handle OFFSET clause for result skipping"""
160+
_logger.debug("Processing OFFSET clause")
161+
try:
162+
if hasattr(ctx, "exprSelect") and ctx.exprSelect():
163+
offset_text = ctx.exprSelect().getText()
164+
try:
165+
offset_value = int(offset_text)
166+
self._parse_result.offset_value = offset_value
167+
_logger.debug(f"Extracted offset value: {offset_value}")
168+
except ValueError as e:
169+
_logger.warning(f"Invalid OFFSET value '{offset_text}': {e}")
170+
return self.visitChildren(ctx)
171+
except Exception as e:
172+
_logger.warning(f"Error processing OFFSET clause: {e}")
173+
return self.visitChildren(ctx)

tests/test_cursor.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import pytest
33

4-
from pymongosql.error import ProgrammingError
4+
from pymongosql.error import DatabaseError, ProgrammingError, SqlSyntaxError
55
from pymongosql.result_set import ResultSet
66

77

@@ -64,9 +64,7 @@ def test_execute_with_limit(self, conn):
6464
assert isinstance(cursor.result_set, ResultSet)
6565
rows = cursor.result_set.fetchall()
6666

67-
# Should return results from 22 users in dataset (LIMIT parsing may not be implemented yet)
68-
# TODO: Fix LIMIT parsing in SQL grammar
69-
assert len(rows) >= 1 # At least we get some results
67+
assert len(rows) == 2 # At least we get some results
7068

7169
# Check that names are present using DB API 2.0
7270
if len(rows) > 0:
@@ -85,7 +83,7 @@ def test_execute_with_skip(self, conn):
8583
rows = cursor.result_set.fetchall()
8684

8785
# Should return users after skipping 1 (from 22 users in dataset)
88-
assert len(rows) >= 0 # Could be 0-21 depending on implementation
86+
assert len(rows) == 21 # 22 - 1 = 21 users after skipping the first one
8987

9088
# Check that results have name field if any results using DB API 2.0
9189
if len(rows) > 0:
@@ -111,11 +109,8 @@ def test_execute_with_sort(self, conn):
111109
assert "name" in col_names
112110
assert all(len(row) >= 1 for row in rows) # All rows should have data
113111

114-
# Verify that we have actual user names from the dataset using DB API 2.0
115-
if "name" in col_names:
116-
name_idx = col_names.index("name")
117-
names = [row[name_idx] for row in rows]
118-
assert "John Doe" in names # First user from dataset
112+
# Verify that the first name in the result
113+
assert "Patricia Johnson" == rows[0][0]
119114

120115
def test_execute_complex_query(self, conn):
121116
"""Test executing complex query with multiple clauses"""
@@ -140,7 +135,7 @@ def test_execute_complex_query(self, conn):
140135

141136
def test_execute_nested_fields_query(self, conn):
142137
"""Test executing query with nested field access"""
143-
sql = "SELECT profile.bio, address.city, address.coordinates FROM users WHERE salary >= 100000"
138+
sql = "SELECT name, profile.bio, address.city FROM users WHERE salary >= 100000 ORDER BY salary DESC"
144139

145140
cursor = conn.cursor()
146141
result = cursor.execute(sql)
@@ -156,17 +151,20 @@ def test_execute_nested_fields_query(self, conn):
156151
if cursor.result_set.description:
157152
col_names = [desc[0] for desc in cursor.result_set.description]
158153
# Should include nested field names in projection
154+
assert "name" in col_names
159155
assert "profile.bio" in col_names
160156
assert "address.city" in col_names
161-
assert "address.coordinates" in col_names
157+
158+
# Verify the first record matched the highest salary
159+
assert "Patricia Johnson" == rows[0][0]
162160

163161
def test_execute_parser_error(self, conn):
164162
"""Test executing query with parser errors"""
165163
sql = "INVALID SQL SYNTAX"
166164

167165
# This should raise an exception due to invalid SQL
168166
cursor = conn.cursor()
169-
with pytest.raises(Exception): # Could be SqlSyntaxError or other parsing error
167+
with pytest.raises(SqlSyntaxError): # Could be SqlSyntaxError or other parsing error
170168
cursor.execute(sql)
171169

172170
def test_execute_database_error(self, conn, make_connection):
@@ -178,7 +176,7 @@ def test_execute_database_error(self, conn, make_connection):
178176

179177
# This should raise an exception due to closed connection
180178
cursor = conn.cursor()
181-
with pytest.raises(Exception): # Could be DatabaseError or OperationalError
179+
with pytest.raises(DatabaseError):
182180
cursor.execute(sql)
183181

184182
# Reconnect for other tests
@@ -188,31 +186,6 @@ def test_execute_database_error(self, conn, make_connection):
188186
finally:
189187
new_conn.close()
190188

191-
def test_execute_with_aliases(self, conn):
192-
"""Test executing query with field aliases"""
193-
sql = "SELECT name AS full_name, email AS user_email FROM users"
194-
cursor = conn.cursor()
195-
result = cursor.execute(sql)
196-
197-
assert result == cursor # execute returns self
198-
assert isinstance(cursor.result_set, ResultSet)
199-
rows = cursor.result_set.fetchall()
200-
201-
# Should return users with aliased field names
202-
assert len(rows) == 22
203-
204-
# Check that alias fields are present if aliasing works using DB API 2.0
205-
col_names = [desc[0] for desc in cursor.result_set.description]
206-
# Aliases might not work yet, so check for either original or alias names
207-
assert "name" in col_names or "full_name" in col_names
208-
# Check for email columns in description
209-
has_email = "email" in col_names or "user_email" in col_names
210-
for row in rows:
211-
assert len(row) >= 2 # Should have at least 2 columns
212-
# Verify we have email data if expected
213-
if has_email:
214-
assert True # Email column exists in description
215-
216189
def test_fetchone_without_execute(self, conn):
217190
"""Test fetchone without previous execute"""
218191
fresh_cursor = conn.cursor()
@@ -238,11 +211,10 @@ def test_fetchone_with_result(self, conn):
238211
# Execute query first
239212
cursor = conn.cursor()
240213
_ = cursor.execute(sql)
241-
242-
# Test fetchone - DB API 2.0 returns sequences, not dicts
243214
row = cursor.fetchone()
215+
244216
assert row is not None
245-
assert isinstance(row, (tuple, list)) # Should be sequence, not dict
217+
assert isinstance(row, (tuple, list))
246218
# Verify we have data using DB API 2.0 approach
247219
col_names = [desc[0] for desc in cursor.result_set.description] if cursor.result_set.description else []
248220
if "name" in col_names:

tests/test_sql_parser_general.py

Lines changed: 23 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,29 @@ def test_select_with_limit(self):
239239
assert execution_plan.limit_stage == 10
240240
assert execution_plan.projection_stage == {"name": 1}
241241

242+
def test_select_with_offset(self):
243+
"""Test SELECT with OFFSET clause"""
244+
sql = "SELECT name FROM users OFFSET 5"
245+
parser = SQLParser(sql)
246+
247+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
248+
execution_plan = parser.get_execution_plan()
249+
assert execution_plan.collection == "users"
250+
assert execution_plan.skip_stage == 5
251+
assert execution_plan.projection_stage == {"name": 1}
252+
253+
def test_select_with_limit_and_offset(self):
254+
"""Test SELECT with both LIMIT and OFFSET clauses"""
255+
sql = "SELECT name, email FROM users LIMIT 10 OFFSET 5"
256+
parser = SQLParser(sql)
257+
258+
assert not parser.has_errors, f"Parser errors: {parser.errors}"
259+
execution_plan = parser.get_execution_plan()
260+
assert execution_plan.collection == "users"
261+
assert execution_plan.limit_stage == 10
262+
assert execution_plan.skip_stage == 5
263+
assert execution_plan.projection_stage == {"name": 1, "email": 1}
264+
242265
def test_complex_query_combination(self):
243266
"""Test complex query with multiple clauses"""
244267
sql = """
@@ -277,83 +300,6 @@ def test_parser_error_handling(self):
277300
parser = SQLParser("INVALID SQL SYNTAX")
278301
parser.get_execution_plan()
279302

280-
def test_select_with_as_aliases(self):
281-
"""Test SELECT with AS aliases"""
282-
sql = "SELECT name AS username, email AS user_email FROM customers"
283-
parser = SQLParser(sql)
284-
285-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
286-
287-
execution_plan = parser.get_execution_plan()
288-
assert execution_plan.collection == "customers"
289-
assert execution_plan.filter_stage == {}
290-
assert execution_plan.projection_stage == {
291-
"name": 1,
292-
"email": 1,
293-
}
294-
295-
def test_select_with_mixed_aliases(self):
296-
"""Test SELECT with mixed alias formats"""
297-
sql = "SELECT name AS username, age user_age, status FROM users"
298-
parser = SQLParser(sql)
299-
300-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
301-
302-
execution_plan = parser.get_execution_plan()
303-
assert execution_plan.collection == "users"
304-
assert execution_plan.filter_stage == {}
305-
assert execution_plan.projection_stage == {
306-
"name": 1, # AS alias
307-
"age": 1, # Space-separated alias
308-
"status": 1, # No alias (field included)
309-
}
310-
311-
def test_select_with_space_separated_aliases(self):
312-
"""Test SELECT with space-separated aliases"""
313-
sql = "SELECT first_name fname, last_name lname, created_at creation_date FROM users"
314-
parser = SQLParser(sql)
315-
316-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
317-
318-
execution_plan = parser.get_execution_plan()
319-
assert execution_plan.collection == "users"
320-
assert execution_plan.filter_stage == {}
321-
assert execution_plan.projection_stage == {
322-
"first_name": 1,
323-
"last_name": 1,
324-
"created_at": 1,
325-
}
326-
327-
def test_select_with_complex_field_names_and_aliases(self):
328-
"""Test SELECT with complex field names and aliases"""
329-
sql = "SELECT user_profile.name AS display_name, account_settings.theme user_theme FROM users"
330-
parser = SQLParser(sql)
331-
332-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
333-
334-
execution_plan = parser.get_execution_plan()
335-
assert execution_plan.collection == "users"
336-
assert execution_plan.filter_stage == {}
337-
assert execution_plan.projection_stage == {
338-
"user_profile.name": 1,
339-
"account_settings.theme": 1,
340-
}
341-
342-
def test_select_function_with_aliases(self):
343-
"""Test SELECT with functions and aliases"""
344-
sql = "SELECT COUNT(*) AS total_count, MAX(age) max_age FROM users"
345-
parser = SQLParser(sql)
346-
347-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
348-
349-
execution_plan = parser.get_execution_plan()
350-
assert execution_plan.collection == "users"
351-
assert execution_plan.filter_stage == {}
352-
assert execution_plan.projection_stage == {
353-
"COUNT(*)": 1,
354-
"MAX(age)": 1,
355-
}
356-
357303
def test_superset_wrapped_subquery(self):
358304
"""Support Superset wrapping subquery with alias 'virtual'"""
359305
sql = "SELECT virtual.a, virtual.b FROM (SELECT a, b FROM users WHERE c = 1) virtual"
@@ -369,49 +315,6 @@ def test_superset_wrapped_subquery(self):
369315
# Inner filter should be preserved
370316
assert execution_plan.filter_stage == {"c": 1}
371317

372-
def test_select_single_field_with_alias(self):
373-
"""Test SELECT with single field and alias"""
374-
sql = "SELECT email AS contact_email FROM customers"
375-
parser = SQLParser(sql)
376-
377-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
378-
379-
execution_plan = parser.get_execution_plan()
380-
assert execution_plan.collection == "customers"
381-
assert execution_plan.filter_stage == {}
382-
assert execution_plan.projection_stage == {"email": 1}
383-
384-
def test_select_aliases_with_where_clause(self):
385-
"""Test SELECT with aliases and WHERE clause"""
386-
sql = "SELECT name AS username, status AS account_status FROM users WHERE age > 18"
387-
parser = SQLParser(sql)
388-
389-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
390-
391-
execution_plan = parser.get_execution_plan()
392-
assert execution_plan.collection == "users"
393-
assert execution_plan.filter_stage == {"age": {"$gt": 18}}
394-
assert execution_plan.projection_stage == {
395-
"name": 1,
396-
"status": 1,
397-
}
398-
399-
def test_select_case_insensitive_as_alias(self):
400-
"""Test SELECT with case insensitive AS keyword"""
401-
sql = "SELECT name as username, email As user_email, status AS account_status FROM users"
402-
parser = SQLParser(sql)
403-
404-
assert not parser.has_errors, f"Parser errors: {parser.errors}"
405-
406-
execution_plan = parser.get_execution_plan()
407-
assert execution_plan.collection == "users"
408-
assert execution_plan.filter_stage == {}
409-
assert execution_plan.projection_stage == {
410-
"name": 1,
411-
"email": 1,
412-
"status": 1,
413-
}
414-
415318
def test_different_collection_names(self):
416319
"""Test parsing with different collection names"""
417320
test_cases = [

0 commit comments

Comments
 (0)