Skip to content

Commit f0738f4

Browse files
passren and Peng Ren authored
Fix bugs for superset query (#5)
Co-authored-by: Peng Ren <ia250@cummins.com>
1 parent 4c4e7ed commit f0738f4

File tree

7 files changed: 467 additions and 45 deletions

pymongosql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING:
77
from .connection import Connection
88

9-
__version__: str = "0.2.2"
9+
__version__: str = "0.2.3"
1010

1111
# Globals https://www.python.org/dev/peps/pep-0249/#globals
1212
apilevel: str = "2.0"

pymongosql/connection.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,14 +45,18 @@ def __init__(
4545
"""
4646
# Check if connection string specifies mode
4747
connection_string = host if isinstance(host, str) else None
48-
self._mode, host = ConnectionHelper.parse_connection_string(connection_string)
48+
mode, host = ConnectionHelper.parse_connection_string(connection_string)
49+
50+
self._mode = kwargs.pop("mode", None)
51+
if not self._mode and mode:
52+
self._mode = mode
4953

5054
# Extract commonly used parameters for backward compatibility
5155
self._host = host or "localhost"
5256
self._port = port or 27017
5357

5458
# Handle database parameter separately (not a MongoClient parameter)
55-
self._database_name = kwargs.pop("database", None) # Remove from kwargs
59+
self._database_name = kwargs.pop("database", None)
5660

5761
# Store all PyMongo parameters to pass through directly
5862
self._pymongo_params = kwargs.copy()

pymongosql/result_set.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -65,10 +65,16 @@ def _process_and_cache_batch(self, batch: List[Dict[str, Any]]) -> None:
6565
self._total_fetched += len(batch)
6666

6767
def _build_description(self) -> None:
68-
"""Build column description from execution plan projection"""
68+
"""Build column description from execution plan projection or established column names"""
6969
if not self._execution_plan.projection_stage:
70-
# No projection specified, description will be built dynamically
71-
self._description = None
70+
# No projection specified, build description from column names if available
71+
if self._column_names:
72+
self._description = [
73+
(col_name, "VARCHAR", None, None, None, None, None) for col_name in self._column_names
74+
]
75+
else:
76+
# Will be built dynamically when columns are established
77+
self._description = None
7278
return
7379

7480
# Build description from projection (now in MongoDB format {field: 1})
@@ -198,10 +204,13 @@ def description(
198204
self,
199205
) -> Optional[List[Tuple[str, str, None, None, None, None, None]]]:
200206
"""Return column description"""
201-
if self._description is None and not self._cache_exhausted:
202-
# Try to fetch one result to build description dynamically
207+
if self._description is None:
208+
# Try to build description from established column names
203209
try:
204-
self._ensure_results_available(1)
210+
if not self._cache_exhausted:
211+
# Fetch one result to establish column names if needed
212+
self._ensure_results_available(1)
213+
205214
if self._column_names:
206215
# Build description from established column names
207216
self._description = [

pymongosql/sqlalchemy_mongodb/__init__.py

Lines changed: 50 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -29,22 +29,31 @@
2929
__supports_sqlalchemy_2x__ = False
3030

3131

32-
def create_engine_url(host: str = "localhost", port: int = 27017, database: str = "test", **kwargs) -> str:
32+
def create_engine_url(
33+
host: str = "localhost", port: int = 27017, database: str = "test", mode: str = "standard", **kwargs
34+
) -> str:
3335
"""Create a SQLAlchemy engine URL for PyMongoSQL.
3436
3537
Args:
3638
host: MongoDB host
3739
port: MongoDB port
3840
database: Database name
41+
mode: Connection mode - "standard" (default) or "superset" (with subquery support)
3942
**kwargs: Additional connection parameters
4043
4144
Returns:
42-
SQLAlchemy URL string (uses mongodb:// format)
45+
SQLAlchemy URL string
4346
4447
Example:
48+
>>> # Standard mode
4549
>>> url = create_engine_url("localhost", 27017, "mydb")
4650
>>> engine = sqlalchemy.create_engine(url)
51+
>>> # Superset mode with subquery support
52+
>>> url = create_engine_url("localhost", 27017, "mydb", mode="superset")
53+
>>> engine = sqlalchemy.create_engine(url)
4754
"""
55+
scheme = "mongodb+superset" if mode == "superset" else "mongodb"
56+
4857
params = []
4958
for key, value in kwargs.items():
5059
params.append(f"{key}={value}")
@@ -53,7 +62,7 @@ def create_engine_url(host: str = "localhost", port: int = 27017, database: str
5362
if param_str:
5463
param_str = "?" + param_str
5564

56-
return f"mongodb://{host}:{port}/{database}{param_str}"
65+
return f"{scheme}://{host}:{port}/{database}{param_str}"
5766

5867

5968
def create_mongodb_url(mongodb_uri: str) -> str:
@@ -77,11 +86,11 @@ def create_mongodb_url(mongodb_uri: str) -> str:
7786
def create_engine_from_mongodb_uri(mongodb_uri: str, **engine_kwargs):
7887
"""Create a SQLAlchemy engine from any MongoDB connection string.
7988
80-
This function handles both mongodb:// and mongodb+srv:// URIs properly.
81-
Use this instead of create_engine() directly for mongodb+srv URIs.
89+
This function handles mongodb://, mongodb+srv://, and mongodb+superset:// URIs properly.
90+
Use this instead of create_engine() directly for special URI schemes.
8291
8392
Args:
84-
mongodb_uri: Standard MongoDB connection string
93+
mongodb_uri: MongoDB connection string (supports standard, SRV, and superset modes)
8594
**engine_kwargs: Additional arguments passed to create_engine
8695
8796
Returns:
@@ -92,6 +101,8 @@ def create_engine_from_mongodb_uri(mongodb_uri: str, **engine_kwargs):
92101
>>> engine = create_engine_from_mongodb_uri("mongodb+srv://user:pass@cluster.net/db")
93102
>>> # For standard MongoDB
94103
>>> engine = create_engine_from_mongodb_uri("mongodb://localhost:27017/mydb")
104+
>>> # For superset mode (with subquery support)
105+
>>> engine = create_engine_from_mongodb_uri("mongodb+superset://localhost:27017/mydb")
95106
"""
96107
try:
97108
from sqlalchemy import create_engine
@@ -109,6 +120,22 @@ def custom_create_connect_args(url):
109120
opts = {"host": mongodb_uri}
110121
return [], opts
111122

123+
engine.dialect.create_connect_args = custom_create_connect_args
124+
return engine
125+
elif mongodb_uri.startswith("mongodb+superset://"):
126+
# For MongoDB+Superset, convert to standard mongodb:// for SQLAlchemy compatibility
127+
# but preserve the superset mode by passing it through connection options
128+
converted_uri = mongodb_uri.replace("mongodb+superset://", "mongodb://")
129+
130+
# Create engine with converted URI
131+
engine = create_engine(converted_uri, **engine_kwargs)
132+
133+
def custom_create_connect_args(url):
134+
# Use original superset URI for actual MongoDB connection
135+
# This preserves the superset mode for subquery support
136+
opts = {"host": mongodb_uri}
137+
return [], opts
138+
112139
engine.dialect.create_connect_args = custom_create_connect_args
113140
return engine
114141
else:
@@ -123,18 +150,18 @@ def register_dialect():
123150
"""Register the PyMongoSQL dialect with SQLAlchemy.
124151
125152
This function handles registration for both SQLAlchemy 1.x and 2.x.
126-
Registers support for standard MongoDB connection strings only.
153+
Registers support for standard, SRV, and superset MongoDB connection strings.
127154
"""
128155
try:
129156
from sqlalchemy.dialects import registry
130157

131158
# Register for standard MongoDB URLs
132159
registry.register("mongodb", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect")
133160

134-
# Try to register both SRV forms so SQLAlchemy can resolve SRV-style URLs
135-
# (either 'mongodb+srv' or the dotted 'mongodb.srv' plugin name).
136-
# Some SQLAlchemy versions accept '+' in scheme names; others import
137-
# the dotted plugin name. Attempt both registrations in one block.
161+
# Try to register SRV and Superset forms so SQLAlchemy can resolve these URL patterns
162+
# (either with '+' or dotted notation for compatibility with different SQLAlchemy versions).
163+
# Some SQLAlchemy versions accept '+' in scheme names; others import the dotted plugin name.
164+
# Attempt all registrations but don't fail if some are not supported.
138165
try:
139166
registry.register("mongodb+srv", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect")
140167
registry.register("mongodb.srv", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect")
@@ -143,6 +170,18 @@ def register_dialect():
143170
# create_engine_from_mongodb_uri by converting 'mongodb+srv' to 'mongodb'.
144171
pass
145172

173+
try:
174+
registry.register(
175+
"mongodb+superset", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect"
176+
)
177+
registry.register(
178+
"mongodb.superset", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect"
179+
)
180+
except Exception:
181+
# If registration fails we fall back to handling Superset URIs in
182+
# create_engine_from_mongodb_uri by converting 'mongodb+superset' to 'mongodb'.
183+
pass
184+
146185
return True
147186
except ImportError:
148187
# Fallback for versions without registry

pymongosql/superset_mongodb/detector.py

Lines changed: 44 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -89,20 +89,58 @@ def extract_outer_query(cls, query: str) -> Optional[Tuple[str, str]]:
8989
"""
9090
Extract outer query with subquery placeholder.
9191
92+
Preserves the complete outer query structure while replacing the subquery
93+
with a reference to the temporary table.
94+
9295
Returns:
93-
Tuple of (outer_query, subquery_alias) or None
96+
Tuple of (outer_query, subquery_alias) or None if not a wrapped subquery
9497
"""
9598
info = cls.detect(query)
9699
if not info.is_wrapped:
97100
return None
98101

99-
# Replace subquery with temporary table reference
100-
outer = cls.WRAPPED_SUBQUERY_PATTERN.sub(
101-
f"SELECT * FROM {info.subquery_alias}",
102-
query,
102+
# Pattern to capture: SELECT <columns> FROM ( <subquery> ) AS <alias> <rest>
103+
# Matches both SELECT col1, col2 and SELECT col1 AS alias1, col2 AS alias2 formats
104+
pattern = re.compile(
105+
r"(SELECT\s+.+?)\s+FROM\s*\(\s*(?:select|SELECT)\s+.+?\s*\)\s+(?:AS\s+)?(\w+)(.*)",
106+
re.IGNORECASE | re.DOTALL,
103107
)
104108

105-
return outer, info.subquery_alias
109+
match = pattern.search(query)
110+
if match:
111+
select_clause = match.group(1).strip()
112+
table_alias = match.group(2)
113+
rest_of_query = match.group(3).strip()
114+
115+
if rest_of_query:
116+
outer = f"{select_clause} FROM {table_alias} {rest_of_query}"
117+
else:
118+
outer = f"{select_clause} FROM {table_alias}"
119+
120+
return outer, table_alias
121+
122+
# If pattern doesn't match exactly, fall back to preserving SELECT clause
123+
# Extract from SELECT to FROM keyword
124+
select_match = re.search(r"(SELECT\s+.+?)\s+FROM", query, re.IGNORECASE | re.DOTALL)
125+
if not select_match:
126+
return None
127+
128+
select_clause = select_match.group(1).strip()
129+
130+
# Extract table alias and rest of query after the closing paren
131+
rest_match = re.search(r"\)\s+(?:AS\s+)?(\w+)(.*)", query, re.IGNORECASE | re.DOTALL)
132+
if rest_match:
133+
table_alias = rest_match.group(1)
134+
rest_of_query = rest_match.group(2).strip()
135+
136+
if rest_of_query:
137+
outer = f"{select_clause} FROM {table_alias} {rest_of_query}"
138+
else:
139+
outer = f"{select_clause} FROM {table_alias}"
140+
141+
return outer, table_alias
142+
143+
return None
106144

107145
@classmethod
108146
def is_simple_select(cls, query: str) -> bool:

pymongosql/superset_mongodb/executor.py

Lines changed: 28 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -105,34 +105,57 @@ def execute(
105105
try:
106106
# Create temporary table with MongoDB results
107107
querydb_query, table_name = SubqueryDetector.extract_outer_query(context.query)
108+
if querydb_query is None or table_name is None:
109+
# Fallback to original query if extraction fails
110+
querydb_query = context.query
111+
table_name = "virtual_table"
112+
108113
query_db.insert_records(table_name, mongo_dicts)
109114

110115
# Execute outer query against intermediate DB
111-
_logger.debug(f"Stage 2: Executing {db_name} query: {querydb_query}")
116+
_logger.debug(f"Stage 2: Executing QueryDBSQLite query: {querydb_query}")
112117

113118
querydb_rows = query_db.execute_query(querydb_query)
114119
_logger.debug(f"Stage 2 complete: Got {len(querydb_rows)} rows from {db_name}")
115120

116121
# Create a ResultSet-like object from intermediate DB results
117122
result_set = self._create_result_set_from_db(querydb_rows, querydb_query)
118123

119-
self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage={})
124+
# Build projection_stage from query database result columns
125+
projection_stage = {}
126+
if querydb_rows and isinstance(querydb_rows[0], dict):
127+
# Extract column names from first result row
128+
for col_name in querydb_rows[0].keys():
129+
projection_stage[col_name] = 1 # 1 means included in projection
130+
else:
131+
# If no rows, get column names from the SQLite query directly
132+
try:
133+
cursor = query_db.execute_query_cursor(querydb_query)
134+
if cursor.description:
135+
# Extract column names from cursor description
136+
for col_desc in cursor.description:
137+
col_name = col_desc[0]
138+
projection_stage[col_name] = 1
139+
except Exception as e:
140+
_logger.warning(f"Could not extract column names from empty result: {e}")
141+
142+
self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage=projection_stage)
120143

121144
return result_set
122145

123146
finally:
124147
query_db.close()
125148

126-
def _create_result_set_from_db(self, rows: List[Dict[str, Any]], query: str) -> ResultSet:
149+
def _create_result_set_from_db(self, rows: List[Dict[str, Any]], query: str) -> Dict[str, Any]:
127150
"""
128-
Create a ResultSet from query database results.
151+
Create a command result from query database results.
129152
130153
Args:
131154
rows: List of dictionaries from query database
132155
query: Original SQL query
133156
134157
Returns:
135-
ResultSet with query database results
158+
Dictionary with command result format
136159
"""
137160
# Create a mock command result structure compatible with ResultSet
138161
command_result = {

0 commit comments

Comments
 (0)