Skip to content

Commit 47cf574

Browse files
author
Peng Ren
committed
Add more cases for query
1 parent 88d5e20 commit 47cf574

File tree

8 files changed

+604
-391
lines changed

8 files changed

+604
-391
lines changed

pymongosql/cursor.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .common import BaseCursor, CursorIterator
99
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
1010
from .result_set import ResultSet
11-
from .sql.builder import QueryPlan
11+
from .sql.builder import ExecutionPlan
1212
from .sql.parser import SQLParser
1313

1414
if TYPE_CHECKING:
@@ -31,7 +31,7 @@ def __init__(self, connection: "Connection", **kwargs) -> None:
3131
self._kwargs = kwargs
3232
self._result_set: Optional[ResultSet] = None
3333
self._result_set_class = ResultSet
34-
self._current_query_plan: Optional[QueryPlan] = None
34+
self._current_execution_plan: Optional[ExecutionPlan] = None
3535
self._mongo_cursor: Optional[MongoCursor] = None
3636
self._is_closed = False
3737

@@ -78,65 +78,66 @@ def _check_closed(self) -> None:
7878
if self._is_closed:
7979
raise ProgrammingError("Cursor is closed")
8080

81-
def _parse_sql(self, sql: str) -> QueryPlan:
82-
"""Parse SQL statement and return QueryPlan"""
81+
def _parse_sql(self, sql: str) -> ExecutionPlan:
82+
"""Parse SQL statement and return ExecutionPlan"""
8383
try:
8484
parser = SQLParser(sql)
85-
query_plan = parser.get_query_plan()
85+
execution_plan = parser.get_execution_plan()
8686

87-
if not query_plan.validate():
87+
if not execution_plan.validate():
8888
raise SqlSyntaxError("Generated query plan is invalid")
8989

90-
return query_plan
90+
return execution_plan
9191

9292
except SqlSyntaxError:
9393
raise
9494
except Exception as e:
9595
_logger.error(f"SQL parsing failed: {e}")
9696
raise SqlSyntaxError(f"Failed to parse SQL: {e}")
9797

98-
def _execute_query_plan(self, query_plan: QueryPlan) -> None:
99-
"""Execute a QueryPlan against MongoDB using db.command"""
98+
def _execute_execution_plan(self, execution_plan: ExecutionPlan) -> None:
99+
"""Execute an ExecutionPlan against MongoDB using db.command"""
100100
try:
101101
# Get database
102-
if not query_plan.collection:
102+
if not execution_plan.collection:
103103
raise ProgrammingError("No collection specified in query")
104104

105105
db = self.connection.database
106106

107107
# Build MongoDB find command
108-
find_command = {"find": query_plan.collection, "filter": query_plan.filter_stage or {}}
108+
find_command = {"find": execution_plan.collection, "filter": execution_plan.filter_stage or {}}
109109

110-
# Convert projection stage from alias mapping to MongoDB format
111-
if query_plan.projection_stage:
112-
# Convert {"field": "alias"} to {"field": 1} for MongoDB
113-
find_command["projection"] = {field: 1 for field in query_plan.projection_stage.keys()}
110+
# Apply projection if specified (already in MongoDB format)
111+
if execution_plan.projection_stage:
112+
find_command["projection"] = execution_plan.projection_stage
114113

115114
# Apply sort if specified
116-
if query_plan.sort_stage:
115+
if execution_plan.sort_stage:
117116
sort_spec = {}
118-
for sort_dict in query_plan.sort_stage:
117+
for sort_dict in execution_plan.sort_stage:
119118
for field, direction in sort_dict.items():
120119
sort_spec[field] = direction
121120
find_command["sort"] = sort_spec
122121

123122
# Apply skip if specified
124-
if query_plan.skip_stage:
125-
find_command["skip"] = query_plan.skip_stage
123+
if execution_plan.skip_stage:
124+
find_command["skip"] = execution_plan.skip_stage
126125

127126
# Apply limit if specified
128-
if query_plan.limit_stage:
129-
find_command["limit"] = query_plan.limit_stage
127+
if execution_plan.limit_stage:
128+
find_command["limit"] = execution_plan.limit_stage
130129

131130
_logger.debug(f"Executing MongoDB command: {find_command}")
132131

133132
# Execute find command directly
134133
result = db.command(find_command)
135134

136135
# Create result set from command result
137-
self._result_set = self._result_set_class(command_result=result, query_plan=query_plan, **self._kwargs)
136+
self._result_set = self._result_set_class(
137+
command_result=result, execution_plan=execution_plan, **self._kwargs
138+
)
138139

139-
_logger.info(f"Query executed successfully on collection '{query_plan.collection}'")
140+
_logger.info(f"Query executed successfully on collection '{execution_plan.collection}'")
140141

141142
except PyMongoError as e:
142143
_logger.error(f"MongoDB command execution failed: {e}")
@@ -161,11 +162,11 @@ def execute(self: _T, operation: str, parameters: Optional[Dict[str, Any]] = Non
161162
_logger.warning("Parameter substitution not yet implemented, ignoring parameters")
162163

163164
try:
164-
# Parse SQL to QueryPlan
165-
self._current_query_plan = self._parse_sql(operation)
165+
# Parse SQL to ExecutionPlan
166+
self._current_execution_plan = self._parse_sql(operation)
166167

167-
# Execute the query plan
168-
self._execute_query_plan(self._current_query_plan)
168+
# Execute the execution plan
169+
self._execute_execution_plan(self._current_execution_plan)
169170

170171
return self
171172

pymongosql/result_set.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .common import CursorIterator
99
from .error import DatabaseError, ProgrammingError
10-
from .sql.builder import QueryPlan
10+
from .sql.builder import ExecutionPlan
1111

1212
_logger = logging.getLogger(__name__)
1313

@@ -19,7 +19,7 @@ def __init__(
1919
self,
2020
command_result: Optional[Dict[str, Any]] = None,
2121
mongo_cursor: Optional[MongoCursor] = None,
22-
query_plan: QueryPlan = None,
22+
execution_plan: ExecutionPlan = None,
2323
arraysize: int = None,
2424
**kwargs,
2525
) -> None:
@@ -32,7 +32,7 @@ def __init__(
3232
# Extract cursor info from command result
3333
self._result_cursor = command_result.get("cursor", {})
3434
self._raw_results = self._result_cursor.get("firstBatch", [])
35-
self._cached_results: List[Dict[str, Any]] = [] # Will be populated after query_plan is set
35+
self._cached_results: List[Dict[str, Any]] = []
3636
elif mongo_cursor is not None:
3737
self._mongo_cursor = mongo_cursor
3838
self._command_result = None
@@ -41,33 +41,33 @@ def __init__(
4141
else:
4242
raise ProgrammingError("Either command_result or mongo_cursor must be provided")
4343

44-
self._query_plan = query_plan
44+
self._execution_plan = execution_plan
4545
self._is_closed = False
4646
self._cache_exhausted = False
4747
self._total_fetched = 0
4848
self._description: Optional[List[Tuple[str, str, None, None, None, None, None]]] = None
4949
self._errors: List[Dict[str, str]] = []
5050

51-
# Apply projection mapping for command results now that query_plan is set
51+
# Apply projection mapping for command results now that execution_plan is set
5252
if command_result is not None and self._raw_results:
5353
self._cached_results = [self._process_document(doc) for doc in self._raw_results]
5454

5555
# Build description from projection
5656
self._build_description()
5757

5858
def _build_description(self) -> None:
59-
"""Build column description from query plan projection"""
60-
if not self._query_plan.projection_stage:
59+
"""Build column description from execution plan projection"""
60+
if not self._execution_plan.projection_stage:
6161
# No projection specified, description will be built dynamically
6262
self._description = None
6363
return
6464

65-
# Build description from projection
65+
# Build description from projection (now in MongoDB format {field: 1})
6666
description = []
67-
for field_name, alias in self._query_plan.projection_stage.items():
67+
for field_name, include_flag in self._execution_plan.projection_stage.items():
6868
# SQL cursor description format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
69-
column_name = alias if alias != field_name else field_name
70-
description.append((column_name, "VARCHAR", None, None, None, None, None))
69+
if include_flag == 1: # Field is included in projection
70+
description.append((field_name, "VARCHAR", None, None, None, None, None))
7171

7272
self._description = description
7373

@@ -111,20 +111,19 @@ def _ensure_results_available(self, count: int = 1) -> None:
111111

112112
def _process_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
113113
"""Process a MongoDB document according to projection mapping"""
114-
if not self._query_plan.projection_stage:
114+
if not self._execution_plan.projection_stage:
115115
# No projection, return document as-is (including _id)
116116
return dict(doc)
117117

118-
# Apply projection mapping
118+
# Apply projection mapping (now using MongoDB format {field: 1})
119119
processed = {}
120-
for field_name, alias in self._query_plan.projection_stage.items():
121-
if field_name in doc:
122-
output_name = alias if alias != field_name else field_name
123-
processed[output_name] = doc[field_name]
124-
elif field_name != "_id": # _id might be excluded by MongoDB
125-
# Field not found, set to None
126-
output_name = alias if alias != field_name else field_name
127-
processed[output_name] = None
120+
for field_name, include_flag in self._execution_plan.projection_stage.items():
121+
if include_flag == 1: # Field is included in projection
122+
if field_name in doc:
123+
processed[field_name] = doc[field_name]
124+
elif field_name != "_id": # _id might be excluded by MongoDB
125+
# Field not found, set to None
126+
processed[field_name] = None
128127

129128
return processed
130129

pymongosql/sql/ast.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict
44

55
from ..error import SqlSyntaxError
6-
from .builder import QueryPlan
6+
from .builder import ExecutionPlan
77
from .handler import BaseHandler, HandlerFactory, ParseResult
88
from .partiql.PartiQLLexer import PartiQLLexer
99
from .partiql.PartiQLParser import PartiQLParser
@@ -46,9 +46,9 @@ def parse_result(self) -> ParseResult:
4646
"""Get the current parse result"""
4747
return self._parse_result
4848

49-
def parse_to_query_plan(self) -> QueryPlan:
50-
"""Convert the parse result to a QueryPlan"""
51-
return QueryPlan(
49+
def parse_to_execution_plan(self) -> ExecutionPlan:
50+
"""Convert the parse result to an ExecutionPlan"""
51+
return ExecutionPlan(
5252
collection=self._parse_result.collection,
5353
filter_stage=self._parse_result.filter_conditions,
5454
projection_stage=self._parse_result.projection,
@@ -114,3 +114,43 @@ def visitWhereClauseSelect(self, ctx: PartiQLParser.WhereClauseSelectContext) ->
114114
except Exception as e:
115115
_logger.warning(f"Error processing WHERE clause: {e}")
116116
return self.visitChildren(ctx)
117+
118+
def visitOrderByClause(self, ctx: PartiQLParser.OrderByClauseContext) -> Any:
119+
"""Handle ORDER BY clause for sorting"""
120+
_logger.debug("Processing ORDER BY clause")
121+
122+
try:
123+
sort_specs = []
124+
if hasattr(ctx, "orderSortSpec") and ctx.orderSortSpec():
125+
for sort_spec in ctx.orderSortSpec():
126+
field_name = sort_spec.expr().getText() if sort_spec.expr() else "_id"
127+
# Check for ASC/DESC (default is ASC = 1)
128+
direction = 1 # ASC
129+
if hasattr(sort_spec, "DESC") and sort_spec.DESC():
130+
direction = -1 # DESC
131+
# Convert to the expected format: List[Dict[str, int]]
132+
sort_specs.append({field_name: direction})
133+
134+
self._parse_result.sort_fields = sort_specs
135+
_logger.debug(f"Extracted sort specifications: {sort_specs}")
136+
return self.visitChildren(ctx)
137+
except Exception as e:
138+
_logger.warning(f"Error processing ORDER BY clause: {e}")
139+
return self.visitChildren(ctx)
140+
141+
def visitLimitClause(self, ctx: PartiQLParser.LimitClauseContext) -> Any:
142+
"""Handle LIMIT clause for result limiting"""
143+
_logger.debug("Processing LIMIT clause")
144+
try:
145+
if hasattr(ctx, "exprSelect") and ctx.exprSelect():
146+
limit_text = ctx.exprSelect().getText()
147+
try:
148+
limit_value = int(limit_text)
149+
self._parse_result.limit_value = limit_value
150+
_logger.debug(f"Extracted limit value: {limit_value}")
151+
except ValueError as e:
152+
_logger.warning(f"Invalid LIMIT value '{limit_text}': {e}")
153+
return self.visitChildren(ctx)
154+
except Exception as e:
155+
_logger.warning(f"Error processing LIMIT clause: {e}")
156+
return self.visitChildren(ctx)

0 commit comments

Comments
 (0)