Skip to content

Commit fe7559d

Browse files
author
Peng Ren
committed
Fix bugs
1 parent 23b5d02 commit fe7559d

File tree

5 files changed

+120
-43
lines changed

5 files changed

+120
-43
lines changed

pymongosql/__init__.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING:
77
from .connection import Connection
88

9-
__version__: str = "0.4.3"
9+
__version__: str = "0.4.4"
1010

1111
# Globals https://www.python.org/dev/peps/pep-0249/#globals
1212
apilevel: str = "2.0"
@@ -36,6 +36,115 @@ def __hash__(self):
3636
return frozenset.__hash__(self)
3737

3838

39+
# DB API 2.0 Type Objects for MongoDB Data Types
40+
# https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors
41+
# Mapping of MongoDB BSON types to DB API 2.0 type objects
42+
43+
# Null/None type
44+
NULL = DBAPITypeObject(("null", "Null", "NULL"))
45+
46+
# String types
47+
STRING = DBAPITypeObject(("string", "str", "String", "VARCHAR", "CHAR", "TEXT"))
48+
49+
# Numeric types - Integer
50+
BINARY = DBAPITypeObject(("binary", "Binary", "BINARY", "VARBINARY", "BLOB", "ObjectId"))
51+
52+
# Numeric types - Integer
53+
NUMBER = DBAPITypeObject(("int", "integer", "long", "int32", "int64", "Integer", "BIGINT", "INT"))
54+
55+
# Numeric types - Decimal/Float
56+
FLOAT = DBAPITypeObject(("double", "decimal", "float", "Double", "DECIMAL", "FLOAT", "NUMERIC"))
57+
58+
# Boolean type
59+
BOOLEAN = DBAPITypeObject(("bool", "boolean", "Bool", "BOOLEAN"))
60+
61+
# Date/Time types
62+
DATE = DBAPITypeObject(("date", "Date", "DATE"))
63+
TIME = DBAPITypeObject(("time", "Time", "TIME"))
64+
DATETIME = DBAPITypeObject(("datetime", "timestamp", "Timestamp", "DATETIME", "TIMESTAMP"))
65+
66+
# Aggregate types
67+
ARRAY = DBAPITypeObject(("array", "Array", "ARRAY", "list"))
68+
OBJECT = DBAPITypeObject(("object", "Object", "OBJECT", "struct", "dict", "document"))
69+
70+
# Special MongoDB types
71+
OBJECTID = DBAPITypeObject(("objectid", "ObjectId", "OBJECTID", "oid"))
72+
REGEX = DBAPITypeObject(("regex", "Regex", "REGEX", "regexp"))
73+
74+
# Map MongoDB BSON type codes to DB API type objects
75+
# This mapping helps cursor.description identify the correct type for each column
76+
_MONGODB_TYPE_MAP = {
77+
"null": NULL,
78+
"string": STRING,
79+
"int": NUMBER,
80+
"integer": NUMBER,
81+
"long": NUMBER,
82+
"int32": NUMBER,
83+
"int64": NUMBER,
84+
"double": FLOAT,
85+
"decimal": FLOAT,
86+
"float": FLOAT,
87+
"bool": BOOLEAN,
88+
"boolean": BOOLEAN,
89+
"date": DATE,
90+
"datetime": DATETIME,
91+
"timestamp": DATETIME,
92+
"array": ARRAY,
93+
"object": OBJECT,
94+
"document": OBJECT,
95+
"bson.objectid": OBJECTID,
96+
"objectid": OBJECTID,
97+
"regex": REGEX,
98+
"binary": BINARY,
99+
}
100+
101+
102+
def get_type_code(value: object) -> str:
103+
"""Get the type code for a MongoDB value.
104+
105+
Maps a MongoDB/Python value to its corresponding DB API type code string.
106+
107+
Args:
108+
value: The value to determine the type for
109+
110+
Returns:
111+
A string representing the DB API type code
112+
"""
113+
if value is None:
114+
return "null"
115+
elif isinstance(value, bool):
116+
return "bool"
117+
elif isinstance(value, int):
118+
return "int"
119+
elif isinstance(value, float):
120+
return "double"
121+
elif isinstance(value, str):
122+
return "string"
123+
elif isinstance(value, bytes):
124+
return "binary"
125+
elif isinstance(value, dict):
126+
return "object"
127+
elif isinstance(value, list):
128+
return "array"
129+
elif hasattr(value, "__class__") and value.__class__.__name__ == "ObjectId":
130+
return "objectid"
131+
else:
132+
return "object"
133+
134+
135+
def get_type_object(value: object) -> DBAPITypeObject:
136+
"""Get the DB API type object for a MongoDB value.
137+
138+
Args:
139+
value: The value to get type information for
140+
141+
Returns:
142+
A DBAPITypeObject representing the value's type
143+
"""
144+
type_code = get_type_code(value)
145+
return _MONGODB_TYPE_MAP.get(type_code, OBJECT)
146+
147+
39148
def connect(*args, **kwargs) -> "Connection":
40149
from .connection import Connection
41150

pymongosql/result_set.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import jmespath
77
from pymongo.errors import PyMongoError
88

9+
from . import STRING
910
from .common import CursorIterator
1011
from .error import DatabaseError, ProgrammingError
1112
from .sql.query_builder import QueryExecutionPlan
@@ -69,7 +70,9 @@ def _build_description(self) -> None:
6970
if not self._execution_plan.projection_stage:
7071
# No projection specified, build description from column names if available
7172
if self._column_names:
72-
self._description = [(col_name, str, None, None, None, None, None) for col_name in self._column_names]
73+
self._description = [
74+
(col_name, STRING, None, None, None, None, None) for col_name in self._column_names
75+
]
7376
else:
7477
# Will be built dynamically when columns are established
7578
self._description = None
@@ -84,7 +87,7 @@ def _build_description(self) -> None:
8487
if include_flag == 1: # Field is included in projection
8588
# Use alias if available, otherwise use field name
8689
display_name = column_aliases.get(field_name, field_name)
87-
description.append((display_name, str, None, None, None, None, None))
90+
description.append((display_name, STRING, None, None, None, None, None))
8891

8992
self._description = description
9093

@@ -226,7 +229,7 @@ def description(
226229
if self._column_names:
227230
# Build description from established column names
228231
self._description = [
229-
(col_name, str, None, None, None, None, None) for col_name in self._column_names
232+
(col_name, STRING, None, None, None, None, None) for col_name in self._column_names
230233
]
231234
except Exception as e:
232235
_logger.warning(f"Could not build dynamic description: {e}")

pymongosql/sql/query_builder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder":
156156
def limit(self, count: int) -> "MongoQueryBuilder":
157157
"""Set limit for results"""
158158
if not isinstance(count, int) or count < 0:
159-
self._add_error("Limit must be a non-negative integer")
160159
return self
161160

162161
self._execution_plan.limit_stage = count
@@ -166,7 +165,6 @@ def limit(self, count: int) -> "MongoQueryBuilder":
166165
def skip(self, count: int) -> "MongoQueryBuilder":
167166
"""Set skip count for pagination"""
168167
if not isinstance(count, int) or count < 0:
169-
self._add_error("Skip must be a non-negative integer")
170168
return self
171169

172170
self._execution_plan.skip_stage = count

tests/test_cursor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from bson.timestamp import Timestamp
66

7+
from pymongosql import STRING
78
from pymongosql.error import DatabaseError, ProgrammingError, SqlSyntaxError
89
from pymongosql.result_set import ResultSet
910

@@ -271,8 +272,8 @@ def test_description_type_and_shape(self, conn):
271272
desc = cursor.description
272273
assert isinstance(desc, list)
273274
assert all(isinstance(d, tuple) and len(d) == 7 and isinstance(d[0], str) for d in desc)
274-
# type_code should be a type object (e.g., str) or None when unknown
275-
assert all((isinstance(d[1], type) or d[1] is None) for d in desc)
275+
# type_code should be a DBAPITypeObject (e.g., STRING) or None when unknown
276+
assert all((d[1] == STRING or d[1] is None) for d in desc)
276277

277278
def test_description_projection(self, conn):
278279
"""Ensure projection via SQL reflects in the description names and types"""
@@ -285,7 +286,7 @@ def test_description_projection(self, conn):
285286
assert "email" in col_names
286287
for d in desc:
287288
if d[0] in ("name", "email"):
288-
assert isinstance(d[1], type) or d[1] is None
289+
assert d[1] == STRING or d[1] is None
289290

290291
def test_cursor_pagination_fetchmany_triggers_getmore(self, conn, monkeypatch):
291292
"""Test that cursor.fetchmany triggers getMore when executing SQL that yields a paginated cursor

tests/test_query_builder.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -221,47 +221,13 @@ def test_limit_valid(self):
221221

222222
assert builder._execution_plan.limit_stage == 100
223223

224-
def test_limit_negative(self):
225-
"""Test limit with negative value adds error."""
226-
builder = MongoQueryBuilder()
227-
builder.limit(-10)
228-
229-
errors = builder.get_errors()
230-
assert len(errors) > 0
231-
assert "non-negative" in errors[0].lower()
232-
233-
def test_limit_non_integer(self):
234-
"""Test limit with non-integer adds error."""
235-
builder = MongoQueryBuilder()
236-
builder.limit(10.5)
237-
238-
errors = builder.get_errors()
239-
assert len(errors) > 0
240-
241224
def test_skip_valid(self):
242225
"""Test skip with valid value."""
243226
builder = MongoQueryBuilder()
244227
builder.skip(50)
245228

246229
assert builder._execution_plan.skip_stage == 50
247230

248-
def test_skip_negative(self):
249-
"""Test skip with negative value adds error."""
250-
builder = MongoQueryBuilder()
251-
builder.skip(-5)
252-
253-
errors = builder.get_errors()
254-
assert len(errors) > 0
255-
assert "non-negative" in errors[0].lower()
256-
257-
def test_skip_non_integer(self):
258-
"""Test skip with non-integer adds error."""
259-
builder = MongoQueryBuilder()
260-
builder.skip("10")
261-
262-
errors = builder.get_errors()
263-
assert len(errors) > 0
264-
265231
def test_column_aliases_valid(self):
266232
"""Test column_aliases with valid dict."""
267233
builder = MongoQueryBuilder()

0 commit comments

Comments
 (0)