diff --git a/row_query/__init__.py b/row_query/__init__.py index 47ad4ad..42fb129 100644 --- a/row_query/__init__.py +++ b/row_query/__init__.py @@ -26,12 +26,14 @@ QueryNotFoundError, RegistryError, RowQueryError, + SQLSanitizationError, StrictModeViolation, TransactionError, TransactionStateError, ) from row_query.core.migration import MigrationInfo, MigrationManager from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager from row_query.mapping.model import ModelMapper @@ -45,6 +47,8 @@ "AsyncEngine", # Registry "SQLRegistry", + # Sanitizer + "SQLSanitizer", # Transaction "TransactionManager", "AsyncTransactionManager", @@ -63,6 +67,7 @@ "ExecutionError", "MultipleRowsError", "ParameterBindingError", + "SQLSanitizationError", "MappingError", "ColumnMismatchError", "StrictModeViolation", diff --git a/row_query/adapters/sqlite.py b/row_query/adapters/sqlite.py index b64f975..4855798 100644 --- a/row_query/adapters/sqlite.py +++ b/row_query/adapters/sqlite.py @@ -50,7 +50,7 @@ def execute( params: dict[str, Any] | None = None, ) -> sqlite3.Cursor: """Execute SQL and return a cursor.""" - return connection.execute(sql, params or {}) + return connection.execute(sql, params if params is not None else {}) class SqliteAsyncAdapter: @@ -95,4 +95,4 @@ async def execute_async( params: dict[str, Any] | None = None, ) -> Any: """Execute SQL asynchronously and return a cursor.""" - return await connection.execute(sql, params or {}) + return await connection.execute(sql, params if params is not None else {}) diff --git a/row_query/core/engine.py b/row_query/core/engine.py index 8ce0238..0cdb372 100644 --- a/row_query/core/engine.py +++ b/row_query/core/engine.py @@ -13,8 +13,9 @@ MultipleRowsError, ParameterBindingError, ) -from row_query.core.params import normalize_params +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager T = TypeVar("T") @@ -69,9 +70,11 @@ def __init__( self, connection_manager: ConnectionManager, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection_manager = connection_manager self._registry = registry + self._sanitizer = sanitizer self._paramstyle = connection_manager.adapter.paramstyle @classmethod @@ -79,46 +82,56 @@ def from_config( cls, config: Any, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> Engine: """Create an Engine from a ConnectionConfig and SQLRegistry. Args: config: ConnectionConfig instance registry: SQLRegistry instance + sanitizer: Optional SQLSanitizer applied to inline SQL strings. Returns: Engine instance """ connection_manager = ConnectionManager(config) - return cls(connection_manager, registry) + return cls(connection_manager, registry, sanitizer) def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: """Fetch a single row. + *query* may be a registry key (e.g. ``"users.get_by_id"``) or an + inline SQL string (e.g. ``"SELECT * FROM users WHERE id = ?"``). + + *params* may be a ``dict`` for named binding, a ``tuple``/``list`` for + positional binding, or a single scalar that is automatically wrapped in + a tuple. + Returns None if zero rows match. Raises MultipleRowsError if more than one row matches. """ - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) if len(rows) == 0: return None if len(rows) > 1: - raise MultipleRowsError(query_name, len(rows)) + raise MultipleRowsError(label, len(rows)) row = rows[0] if mapper is not None: @@ -127,20 +140,25 @@ def fetch_one( def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch all matching rows.""" - sql = self._registry.get(query_name) + """Fetch all matching rows. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) @@ -150,18 +168,23 @@ def fetch_all( def fetch_scalar( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> Any: - """Fetch a single scalar value (first column of first row).""" - sql = self._registry.get(query_name) + """Fetch a single scalar value (first column of first row). + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e row = cursor.fetchone() if row is None: @@ -177,18 +200,23 @@ def fetch_scalar( def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query. Returns affected row count.""" - sql = self._registry.get(query_name) + """Execute a write query. Returns affected row count. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e conn.commit() return int(cursor.rowcount) @@ -204,6 +232,7 @@ def transaction(self) -> TransactionManager: adapter=self._connection_manager.adapter, registry=self._registry, pool=pool, + sanitizer=self._sanitizer, ) @@ -214,9 +243,11 @@ def __init__( self, connection_manager: AsyncConnectionManager, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection_manager = connection_manager self._registry = registry + self._sanitizer = sanitizer self._paramstyle = connection_manager.adapter.paramstyle @classmethod @@ -224,35 +255,42 @@ def from_config( cls, config: Any, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> AsyncEngine: """Create an AsyncEngine from a ConnectionConfig and SQLRegistry. Args: config: ConnectionConfig instance registry: SQLRegistry instance + sanitizer: Optional SQLSanitizer applied to inline SQL strings. Returns: AsyncEngine instance """ connection_manager = AsyncConnectionManager(config) - return cls(connection_manager, registry) + return cls(connection_manager, registry, sanitizer) async def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch a single row asynchronously.""" - sql = self._registry.get(query_name) + """Fetch a single row asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return None @@ -268,7 +306,7 @@ async def fetch_one( if len(rows) == 0: return None if len(rows) > 1: - raise MultipleRowsError(query_name, len(rows)) + raise MultipleRowsError(label, len(rows)) row = rows[0] if mapper is not None: @@ -277,20 +315,25 @@ async def fetch_one( async def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch all matching rows asynchronously.""" - sql = self._registry.get(query_name) + """Fetch all matching rows asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return [] @@ -309,18 +352,23 @@ async def fetch_all( async def fetch_scalar( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> Any: - """Fetch a single scalar value asynchronously.""" - sql = self._registry.get(query_name) + """Fetch a single scalar value asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e row = await cursor.fetchone() if row is None: @@ -336,18 +384,23 @@ async def fetch_scalar( async def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query asynchronously.""" - sql = self._registry.get(query_name) + """Execute a write query asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e await conn.commit() return int(cursor.rowcount) @@ -362,4 +415,5 @@ def transaction(self) -> AsyncTransactionManager: adapter=self._connection_manager.adapter, registry=self._registry, connection_manager=self._connection_manager, + sanitizer=self._sanitizer, ) diff --git a/row_query/core/exceptions.py b/row_query/core/exceptions.py index 0101332..313bc38 100644 --- a/row_query/core/exceptions.py +++ b/row_query/core/exceptions.py @@ -60,6 +60,13 @@ def __init__(self, query_name: str, detail: str) -> None: super().__init__(f"Parameter binding error for '{query_name}': {detail}") +class SQLSanitizationError(ExecutionError): + """Raised when an inline SQL string fails a sanitization check.""" + + def __init__(self, detail: str) -> None: + super().__init__(f"SQL sanitization failed: {detail}") + + # --- Mapping --- diff --git a/row_query/core/params.py b/row_query/core/params.py index 0cacdef..a78c0f0 100644 --- a/row_query/core/params.py +++ b/row_query/core/params.py @@ -8,6 +8,11 @@ import re from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from row_query.core.registry import SQLRegistry + from row_query.core.sanitizer import SQLSanitizer # Matches :name but not ::typecast and not inside words # Negative lookbehind for : (handles ::), \w (handles mid-word colons) @@ -53,3 +58,65 @@ def _convert_to_pyformat(sql: str) -> str: parts.append(_PARAM_PATTERN.sub(r"%(\1)s", sql[last_end:])) return "".join(parts) + + +def is_raw_sql(query: str) -> bool: + """Return True if query is an inline SQL string rather than a registry key. + + Registry keys use dot-notation (e.g. ``users.get_by_id``) and never + contain whitespace. Any SQL statement will contain at least one space. + + Note: Registry keys are validated during registration to ensure they do + not contain whitespace, preventing ambiguity. + """ + return any(c.isspace() for c in query) + + +def coerce_params( + params: dict[str, Any] | tuple[Any, ...] | list[Any] | Any, +) -> dict[str, Any] | tuple[Any, ...] | None: + """Normalize *params* to a dict, tuple, or None. + + * ``None`` / ``dict`` → returned as-is (named parameter binding). + * ``tuple`` / ``list`` → converted to ``tuple`` (positional binding). + * Any other scalar → wrapped in a single-element tuple. + + Note on parameter styles: + Registry queries use `:name` style parameters (converted to driver format). + Inline SQL can use either `:name` or `?`-style placeholders depending on + the database driver. When using inline SQL with positional parameters, + ensure compatibility with your target database (SQLite uses `?`, PostgreSQL + uses `$1`, etc.). + """ + if params is None or isinstance(params, dict): + return params + if isinstance(params, (tuple, list)): + return tuple(params) + return (params,) + + +def resolve_sql( + query: str, + registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, +) -> tuple[str, str]: + """Return ``(sql_text, label)`` for *query*. + + If *query* is an inline SQL string (contains whitespace) it is returned + after optional sanitization. Otherwise it is looked up in *registry* by + name (registry queries are trusted and never sanitized). *label* is used + in error messages. + + Args: + query: Either a registry key (e.g. "users.get_by_id") or inline SQL. + registry: SQLRegistry instance for looking up named queries. + sanitizer: Optional SQLSanitizer applied only to inline SQL strings. + + Returns: + Tuple of (sql_text, label) where label is "" for inline SQL + or the registry key for named queries. + """ + if is_raw_sql(query): + sql = sanitizer.sanitize(query) if sanitizer is not None else query + return sql, "" + return registry.get(query), query diff --git a/row_query/core/registry.py b/row_query/core/registry.py index aa031f4..9e18f0c 100644 --- a/row_query/core/registry.py +++ b/row_query/core/registry.py @@ -43,6 +43,17 @@ def _load(self) -> None: parts[-1] = parts[-1].removesuffix(".sql") query_name = ".".join(parts) + # Validate that query_name doesn't contain whitespace + # This prevents ambiguity with inline SQL detection + if any(c.isspace() for c in query_name): + from row_query.core.exceptions import ExecutionError + + raise ExecutionError( + f"Registry key '{query_name}' from file '{sql_file}' contains " + f"whitespace, which is not allowed. Registry keys must not contain " + f"spaces, tabs, or newlines to avoid ambiguity with inline SQL." + ) + if query_name in self._queries: raise DuplicateQueryError( query_name, diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py new file mode 100644 index 0000000..919e87d --- /dev/null +++ b/row_query/core/sanitizer.py @@ -0,0 +1,233 @@ +"""Inline SQL sanitizer. + +Only applied to raw SQL strings passed directly to engine/transaction methods. +Queries loaded from the SQLRegistry are trusted and never sanitized. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +from row_query.core.exceptions import SQLSanitizationError + +# Matches the first SQL keyword (used for verb allow-listing) +_FIRST_KEYWORD = re.compile(r"^\s*(\w+)") + + +# --------------------------------------------------------------------------- +# Internal tokenizer +# --------------------------------------------------------------------------- + + +def _tokenize(sql: str) -> list[tuple[str, str]]: + """Split *sql* into ``('string', …)``, ``('identifier', …)``, and ``('code', …)`` tokens. + + String literals (single-quoted, with ``''`` escapes) are preserved as-is. + Identifiers (double-quoted for PostgreSQL/MySQL ANSI_QUOTES, backtick-quoted + for MySQL) are also preserved to avoid stripping comment-like syntax inside them. + Everything else is a ``'code'`` token. + + Raises: + SQLSanitizationError: If an unterminated string literal or identifier is detected. + """ + tokens: list[tuple[str, str]] = [] + i = 0 + n = len(sql) + last = 0 + + while i < n: + # Single-quoted string literal + if sql[i] == "'": + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == "'": + j += 1 + if j >= n or sql[j] != "'": + break # end of literal + j += 1 # '' escape — continue + else: + j += 1 + # Check for unterminated string + if j >= n and (j == i + 1 or sql[j - 1] != "'"): + from row_query.core.exceptions import SQLSanitizationError + + raise SQLSanitizationError("Unterminated string literal detected in SQL") + tokens.append(("string", sql[i:j])) + last = j + i = j + # Double-quoted identifier (PostgreSQL, MySQL ANSI_QUOTES) + elif sql[i] == '"': + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == '"': + j += 1 + if j >= n or sql[j] != '"': + break # end of identifier + j += 1 # "" escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != '"'): + from row_query.core.exceptions import SQLSanitizationError + + raise SQLSanitizationError("Unterminated double-quoted identifier detected in SQL") + tokens.append(("identifier", sql[i:j])) + last = j + i = j + # Backtick-quoted identifier (MySQL) + elif sql[i] == "`": + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == "`": + j += 1 + if j >= n or sql[j] != "`": + break # end of identifier + j += 1 # `` escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != "`"): + from row_query.core.exceptions import SQLSanitizationError + + raise SQLSanitizationError( + "Unterminated backtick-quoted identifier detected in SQL" + ) + tokens.append(("identifier", sql[i:j])) + last = j + i = j + else: + i += 1 + + if last < n: + tokens.append(("code", sql[last:])) + + return tokens + + +def _strip_comments_in_code(code: str) -> str: + """Remove ``--`` line comments and ``/* */`` block comments from a code segment.""" + result: list[str] = [] + i = 0 + n = len(code) + + while i < n: + if code[i : i + 2] == "--": + j = code.find("\n", i) + if j == -1: + break + result.append("\n") + i = j + 1 + elif code[i : i + 2] == "/*": + j = code.find("*/", i + 2) + if j == -1: + break + result.append(" ") + i = j + 2 + else: + result.append(code[i]) + i += 1 + + return "".join(result) + + +# --------------------------------------------------------------------------- +# Individual sanitization checks +# --------------------------------------------------------------------------- + + +def _strip_comments(sql: str) -> str: + """Remove SQL comments while preserving string literals and identifiers.""" + parts: list[str] = [] + for kind, content in _tokenize(sql): + if kind in ("string", "identifier"): + parts.append(content) + else: + parts.append(_strip_comments_in_code(content)) + return "".join(parts) + + +def _check_single_statement(sql: str) -> None: + """Raise if *sql* contains a semicolon followed by non-whitespace content.""" + for kind, content in _tokenize(sql): + if kind in ("string", "identifier"): + continue + for i, ch in enumerate(content): + if ch == ";" and content[i + 1 :].strip(): + raise SQLSanitizationError( + "Multiple SQL statements are not permitted in inline SQL" + ) + + +def _check_verb(sql: str, allowed: frozenset[str]) -> None: + """Raise if the leading SQL keyword is not in *allowed*.""" + m = _FIRST_KEYWORD.match(sql) + if m: + verb = m.group(1).upper() + if verb not in allowed: + raise SQLSanitizationError( + f"SQL verb '{verb}' is not permitted; allowed: {sorted(allowed)}" + ) + + +# --------------------------------------------------------------------------- +# Public class +# --------------------------------------------------------------------------- + + +@dataclass +class SQLSanitizer: + """Configurable sanitizer for inline SQL strings. + + Applied only to raw SQL passed directly to engine/transaction methods. + Registry-loaded queries are always trusted and never sanitized. + + **IMPORTANT SECURITY WARNING:** + This sanitizer does NOT protect against SQL injection if user-provided + data is concatenated directly into SQL strings. You MUST use parameterized + queries with placeholders (e.g., `?` or `:name`) to prevent SQL injection. + The sanitizer only provides defense-in-depth measures (comment stripping, + statement blocking, verb restrictions) but is NOT a substitute for proper + parameterization. + + Example of UNSAFE code: + # NEVER DO THIS - vulnerable to SQL injection + engine.fetch_all(f"SELECT * FROM users WHERE name = '{user_input}'") + + Example of SAFE code: + # ALWAYS USE THIS - parameterized query + engine.fetch_all("SELECT * FROM users WHERE name = ?", user_input) + + Attributes: + strip_comments: Strip ``--`` and ``/* */`` comments before execution. + block_multiple_statements: Reject SQL that contains a statement- + terminating ``;`` followed by additional content (prevents query + stacking such as ``SELECT 1; DROP TABLE users``). + allowed_verbs: If not ``None``, only SQL statements whose first keyword + appears in this set are permitted. ``None`` means no restriction. + Example: ``frozenset({"SELECT", "INSERT", "UPDATE", "DELETE"})``. + """ + + strip_comments: bool = True + block_multiple_statements: bool = True + allowed_verbs: frozenset[str] | None = None + + def sanitize(self, sql: str) -> str: + """Apply all configured checks to *sql* and return the (cleaned) SQL. + + Raises: + SQLSanitizationError: If any enabled check fails. + """ + if self.strip_comments: + sql = _strip_comments(sql) + if self.block_multiple_statements: + _check_single_statement(sql) + if self.allowed_verbs is not None: + _check_verb(sql, self.allowed_verbs) + return sql diff --git a/row_query/core/transaction.py b/row_query/core/transaction.py index 94bd1d1..8ebd140 100644 --- a/row_query/core/transaction.py +++ b/row_query/core/transaction.py @@ -9,9 +9,10 @@ from enum import Enum from typing import Any -from row_query.core.exceptions import TransactionStateError -from row_query.core.params import normalize_params +from row_query.core.exceptions import ParameterBindingError, TransactionStateError +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer class _TxState(Enum): @@ -51,11 +52,13 @@ def __init__( adapter: Any, registry: SQLRegistry, pool: Any = None, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection = connection self._adapter = adapter self._registry = registry self._pool = pool + self._sanitizer = sanitizer self._paramstyle: str = adapter.paramstyle self._state = _TxState.IDLE @@ -83,26 +86,40 @@ def __exit__( def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query within this transaction.""" + """Execute a write query within this transaction. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> dict[str, Any] | None: - """Fetch a single row within transaction context.""" + """Fetch a single row within transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) if not rows: return None @@ -110,14 +127,21 @@ def fetch_one( def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> list[dict[str, Any]]: - """Fetch all rows within transaction context.""" + """Fetch all rows within transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return _rows_to_dicts(cursor) def commit(self) -> None: @@ -155,11 +179,13 @@ def __init__( connection_manager: Any, connection: Any = None, pool: Any = None, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection = connection self._adapter = adapter self._registry = registry self._pool = pool + self._sanitizer = sanitizer self._connection_manager = connection_manager self._paramstyle: str = adapter.paramstyle self._state = _TxState.IDLE @@ -194,26 +220,48 @@ async def __aexit__( async def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query within this async transaction.""" + """Execute a write query within this async transaction. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + try: + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) async def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> dict[str, Any] | None: - """Fetch a single row within async transaction context.""" + """Fetch a single row within async transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + try: + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return None columns = [desc[0] for desc in cursor.description] @@ -231,14 +279,25 @@ async def fetch_one( async def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> list[dict[str, Any]]: - """Fetch all rows within async transaction context.""" + """Fetch all rows within async transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + try: + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return [] columns = [desc[0] for desc in cursor.description] diff --git a/tests/unit/test_sanitizer.py b/tests/unit/test_sanitizer.py new file mode 100644 index 0000000..aa2087d --- /dev/null +++ b/tests/unit/test_sanitizer.py @@ -0,0 +1,345 @@ +"""Unit tests for SQLSanitizer and related helpers.""" + +from __future__ import annotations + +import pytest + +from row_query.core.exceptions import SQLSanitizationError +from row_query.core.params import coerce_params, is_raw_sql +from row_query.core.sanitizer import ( + SQLSanitizer, + _check_single_statement, + _check_verb, + _strip_comments, +) + +# --------------------------------------------------------------------------- +# is_raw_sql +# --------------------------------------------------------------------------- + + +class TestIsRawSql: + def test_sql_with_space_is_raw(self) -> None: + assert is_raw_sql("SELECT 1") is True + + def test_sql_with_newline_is_raw(self) -> None: + assert is_raw_sql("SELECT\n1") is True + + def test_sql_with_tab_is_raw(self) -> None: + assert is_raw_sql("SELECT\t1") is True + + def test_registry_key_is_not_raw(self) -> None: + assert is_raw_sql("users.get_by_id") is False + + def test_dotted_registry_key_is_not_raw(self) -> None: + assert is_raw_sql("billing.invoice.list") is False + + def test_empty_string_is_not_raw(self) -> None: + assert is_raw_sql("") is False + + +# --------------------------------------------------------------------------- +# coerce_params +# --------------------------------------------------------------------------- + + +class TestCoerceParams: + def test_none_passthrough(self) -> None: + assert coerce_params(None) is None + + def test_dict_passthrough(self) -> None: + p = {"id": 1} + assert coerce_params(p) is p + + def test_tuple_passthrough(self) -> None: + p = (1, 2) + assert coerce_params(p) == (1, 2) + + def test_list_converted_to_tuple(self) -> None: + assert coerce_params([1, 2]) == (1, 2) + + def test_int_scalar_wrapped(self) -> None: + assert coerce_params(42) == (42,) + + def test_str_scalar_wrapped(self) -> None: + assert coerce_params("hello") == ("hello",) + + def test_bool_scalar_wrapped(self) -> None: + assert coerce_params(True) == (True,) + + def test_empty_tuple_passthrough(self) -> None: + assert coerce_params(()) == () + + def test_empty_list_converted(self) -> None: + assert coerce_params([]) == () + + +# --------------------------------------------------------------------------- +# _strip_comments +# --------------------------------------------------------------------------- + + +class TestStripComments: + def test_no_comments_passthrough(self) -> None: + sql = "SELECT id FROM users WHERE id = 1" + assert _strip_comments(sql) == sql + + def test_line_comment_removed(self) -> None: + sql = "SELECT 1 -- get one" + assert _strip_comments(sql) == "SELECT 1 " + + def test_line_comment_preserves_newline(self) -> None: + sql = "SELECT 1 -- comment\nFROM t" + assert _strip_comments(sql) == "SELECT 1 \nFROM t" + + def test_line_comment_at_start(self) -> None: + sql = "-- full line comment\nSELECT 1" + assert _strip_comments(sql) == "\nSELECT 1" + + def test_multiple_line_comments(self) -> None: + sql = "SELECT 1 -- first\nFROM t -- second\nWHERE 1=1" + assert _strip_comments(sql) == "SELECT 1 \nFROM t \nWHERE 1=1" + + def test_block_comment_removed(self) -> None: + sql = "SELECT /* inline */ 1" + assert _strip_comments(sql) == "SELECT 1" + + def test_block_comment_replaced_with_space(self) -> None: + # Block comments become a single space to avoid token merging + sql = "SELECT/*comment*/1" + assert _strip_comments(sql) == "SELECT 1" + + def test_multiline_block_comment_removed(self) -> None: + sql = "SELECT /*\n big\n comment\n*/ 1" + assert _strip_comments(sql) == "SELECT 1" + + def test_string_literal_with_double_dash_preserved(self) -> None: + sql = "SELECT '-- not a comment' FROM t" + assert _strip_comments(sql) == sql + + def test_string_literal_with_block_comment_syntax_preserved(self) -> None: + sql = "SELECT '/* not a comment */' FROM t" + assert _strip_comments(sql) == sql + + def test_string_literal_adjacent_to_comment(self) -> None: + sql = "SELECT 'value' -- comment\nFROM t" + assert _strip_comments(sql) == "SELECT 'value' \nFROM t" + + def test_escaped_quote_in_string_literal(self) -> None: + sql = "SELECT 'it''s fine -- not a comment' FROM t" + assert _strip_comments(sql) == sql + + def test_unclosed_block_comment_strips_remainder(self) -> None: + sql = "SELECT 1 /* unclosed" + result = _strip_comments(sql) + assert "/*" not in result + assert result.startswith("SELECT 1 ") + + def test_line_comment_without_trailing_newline_strips_to_end(self) -> None: + sql = "SELECT 1 -- trailing only" + result = _strip_comments(sql) + assert "--" not in result + + +# --------------------------------------------------------------------------- +# _check_single_statement +# --------------------------------------------------------------------------- + + +class TestCheckSingleStatement: + def test_no_semicolon_passes(self) -> None: + _check_single_statement("SELECT 1") # no raise + + def test_trailing_semicolon_passes(self) -> None: + _check_single_statement("SELECT 1;") + + def test_trailing_semicolon_with_whitespace_passes(self) -> None: + _check_single_statement("SELECT 1; ") + + def test_trailing_semicolon_with_newline_passes(self) -> None: + _check_single_statement("SELECT 1;\n") + + def test_multiple_statements_raises(self) -> None: + with pytest.raises(SQLSanitizationError, match="Multiple SQL statements"): + _check_single_statement("SELECT 1; DROP TABLE users") + + def test_two_selects_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT 1; SELECT 2") + + def test_no_space_between_statements_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT 1;SELECT 2") + + def test_semicolon_in_string_literal_passes(self) -> None: + _check_single_statement("SELECT ';' FROM t") + + def test_multiple_semicolons_in_string_passes(self) -> None: + _check_single_statement("SELECT 'a;b;c' FROM t") + + def test_semicolon_after_string_then_statement_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT ';' FROM t; DROP TABLE t") + + +# --------------------------------------------------------------------------- +# _check_verb +# --------------------------------------------------------------------------- + + +class TestCheckVerb: + ALLOWED = frozenset({"SELECT", "INSERT", "UPDATE", "DELETE"}) + + def test_allowed_verb_passes(self) -> None: + _check_verb("SELECT * FROM t", self.ALLOWED) + + def test_disallowed_verb_raises(self) -> None: + with pytest.raises(SQLSanitizationError, match="DROP"): + _check_verb("DROP TABLE users", self.ALLOWED) + + def test_lowercase_verb_matches_case_insensitively(self) -> None: + _check_verb("select * FROM t", self.ALLOWED) + + def test_mixed_case_verb_matches(self) -> None: + _check_verb("Select * FROM t", self.ALLOWED) + + def test_leading_whitespace_ignored(self) -> None: + _check_verb(" \n SELECT * FROM t", self.ALLOWED) + + def test_truncate_blocked(self) -> None: + with pytest.raises(SQLSanitizationError, match="TRUNCATE"): + _check_verb("TRUNCATE TABLE users", self.ALLOWED) + + def test_alter_blocked(self) -> None: + with pytest.raises(SQLSanitizationError, match="ALTER"): + _check_verb("ALTER TABLE users ADD COLUMN foo TEXT", self.ALLOWED) + + def test_with_cte_can_be_allowed(self) -> None: + allowed = self.ALLOWED | frozenset({"WITH"}) + _check_verb("WITH cte AS (SELECT 1) SELECT * FROM cte", allowed) + + def test_error_message_lists_allowed_verbs(self) -> None: + with pytest.raises(SQLSanitizationError, match="allowed"): + _check_verb("DROP TABLE t", frozenset({"SELECT"})) + + +# --------------------------------------------------------------------------- +# SQLSanitizer — configuration and defaults +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerDefaults: + def test_default_strips_comments(self) -> None: + s = SQLSanitizer() + result = s.sanitize("SELECT 1 -- comment") + assert "--" not in result + + def test_default_blocks_multiple_statements(self) -> None: + s = SQLSanitizer() + with pytest.raises(SQLSanitizationError): + s.sanitize("SELECT 1; DROP TABLE t") + + def test_default_allows_any_verb(self) -> None: + s = SQLSanitizer() + s.sanitize("DROP TABLE t") # no raise — no verb restriction by default + + def test_returns_cleaned_sql_string(self) -> None: + s = SQLSanitizer() + result = s.sanitize("SELECT /* inline */ 1") + assert isinstance(result, str) + assert "/*" not in result + + +class TestSQLSanitizerFlags: + def test_strip_comments_false_preserves_comments(self) -> None: + s = SQLSanitizer(strip_comments=False) + sql = "SELECT 1 -- comment" + assert s.sanitize(sql) == sql + + def test_block_multiple_statements_false_allows_multi(self) -> None: + s = SQLSanitizer(block_multiple_statements=False) + # Should not raise + s.sanitize("SELECT 1; SELECT 2") + + def test_allowed_verbs_none_permits_any(self) -> None: + s = SQLSanitizer(allowed_verbs=None) + s.sanitize("DROP TABLE t") # no raise + + def test_allowed_verbs_set_blocks_others(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError): + s.sanitize("DROP TABLE t") + + def test_allowed_verbs_set_permits_listed(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT", "INSERT"})) + s.sanitize("INSERT INTO t VALUES (1)") # no raise + + def test_all_checks_disabled(self) -> None: + s = SQLSanitizer( + strip_comments=False, + block_multiple_statements=False, + allowed_verbs=None, + ) + # Nothing should raise, nothing should be modified + dangerous = "DROP TABLE t; DELETE FROM users -- yolo" + assert s.sanitize(dangerous) == dangerous + + +# --------------------------------------------------------------------------- +# SQLSanitizer — ordering: comments stripped before statement check +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerOrdering: + def test_comment_stripped_before_statement_check(self) -> None: + # After stripping, the trailing semicolon has only whitespace after it + s = SQLSanitizer() + sql = "SELECT 1; -- this is a comment, not a second statement" + # Should NOT raise: after stripping the comment, only "SELECT 1; " remains + result = s.sanitize(sql) + assert "SELECT 1" in result + + def test_comment_stripped_before_verb_check(self) -> None: + # A comment-only line before the real verb should not confuse the checker + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + sql = "-- preamble\nSELECT * FROM t" + s.sanitize(sql) # no raise + + def test_multiline_with_trailing_comment_passes(self) -> None: + s = SQLSanitizer() + sql = "SELECT *\nFROM users\nWHERE id = 1 -- primary key" + result = s.sanitize(sql) + assert "FROM users" in result + assert "--" not in result + + +# --------------------------------------------------------------------------- +# SQLSanitizer — error message quality +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerErrorMessages: + def test_multiple_statements_error_message(self) -> None: + s = SQLSanitizer() + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("SELECT 1; DROP TABLE t") + assert "Multiple SQL statements" in str(exc_info.value) + + def test_verb_error_includes_offending_verb(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("TRUNCATE TABLE users") + assert "TRUNCATE" in str(exc_info.value) + + def test_verb_error_includes_allowed_list(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("DROP TABLE t") + assert "SELECT" in str(exc_info.value) + + def test_exception_is_subclass_of_execution_error(self) -> None: + from row_query.core.exceptions import ExecutionError + + s = SQLSanitizer() + with pytest.raises(ExecutionError): + s.sanitize("SELECT 1; DROP TABLE t")