From 324614770f2f8edb50fc92ddf250ef52d1acd1e1 Mon Sep 17 00:00:00 2001 From: xmaksutx Date: Tue, 17 Feb 2026 22:30:27 +0100 Subject: [PATCH 1/7] Add inline SQL support with configurable sanitization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Engine/AsyncEngine/TransactionManagers now accept raw SQL strings alongside registry keys; whitespace presence distinguishes the two - Positional params supported: scalar, list, or tuple are coerced to tuple before binding (dict remains named-param style) - New SQLSanitizer dataclass (strip_comments, block_multiple_statements, allowed_verbs) applied only to inline SQL; registry queries are trusted - SQLSanitizationError added to exception hierarchy - SQLite adapter: params or {} → params if params is not None else {} so positional tuples are not coerced to empty dict - 65 new unit tests covering all sanitizer code paths Co-Authored-By: Claude Sonnet 4.6 --- row_query/__init__.py | 5 + row_query/adapters/sqlite.py | 4 +- row_query/core/engine.py | 176 ++++++++++++----- row_query/core/exceptions.py | 7 + row_query/core/params.py | 26 +++ row_query/core/sanitizer.py | 163 ++++++++++++++++ row_query/core/transaction.py | 103 +++++++--- tests/unit/test_sanitizer.py | 346 ++++++++++++++++++++++++++++++++++ 8 files changed, 745 insertions(+), 85 deletions(-) create mode 100644 row_query/core/sanitizer.py create mode 100644 tests/unit/test_sanitizer.py diff --git a/row_query/__init__.py b/row_query/__init__.py index 47ad4ad..42fb129 100644 --- a/row_query/__init__.py +++ b/row_query/__init__.py @@ -26,12 +26,14 @@ QueryNotFoundError, RegistryError, RowQueryError, + SQLSanitizationError, StrictModeViolation, TransactionError, TransactionStateError, ) from row_query.core.migration import MigrationInfo, MigrationManager from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager from row_query.mapping.model import ModelMapper @@ -45,6 +47,8 @@ "AsyncEngine", # Registry "SQLRegistry", + # Sanitizer + "SQLSanitizer", # Transaction "TransactionManager", "AsyncTransactionManager", @@ -63,6 +67,7 @@ "ExecutionError", "MultipleRowsError", "ParameterBindingError", + "SQLSanitizationError", "MappingError", "ColumnMismatchError", "StrictModeViolation", diff --git a/row_query/adapters/sqlite.py b/row_query/adapters/sqlite.py index b64f975..4855798 100644 --- a/row_query/adapters/sqlite.py +++ b/row_query/adapters/sqlite.py @@ -50,7 +50,7 @@ def execute( params: dict[str, Any] | None = None, ) -> sqlite3.Cursor: """Execute SQL and return a cursor.""" - return connection.execute(sql, params or {}) + return connection.execute(sql, params if params is not None else {}) class SqliteAsyncAdapter: @@ -95,4 +95,4 @@ async def execute_async( params: dict[str, Any] | None = None, ) -> Any: """Execute SQL asynchronously and return a cursor.""" - return await connection.execute(sql, params or {}) + return await connection.execute(sql, params if params is not None else {}) diff --git a/row_query/core/engine.py b/row_query/core/engine.py index 8ce0238..98f80f5 100644 --- a/row_query/core/engine.py +++ b/row_query/core/engine.py @@ -13,13 +13,32 @@ MultipleRowsError, ParameterBindingError, ) -from row_query.core.params import normalize_params +from row_query.core.params import coerce_params, is_raw_sql, normalize_params from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager T = TypeVar("T") +def _resolve_sql( + query: str, + registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, +) -> tuple[str, str]: + """Return ``(sql_text, label)`` for *query*. + + If *query* is an inline SQL string (contains whitespace) it is returned + after optional sanitization. Otherwise it is looked up in *registry* by + name (registry queries are trusted and never sanitized). *label* is used + in error messages. + """ + if is_raw_sql(query): + sql = sanitizer.sanitize(query) if sanitizer is not None else query + return sql, "" + return registry.get(query), query + + def _rows_to_dicts(cursor: Any) -> list[dict[str, Any]]: """Convert cursor results to list of dicts. @@ -69,9 +88,11 @@ def __init__( self, connection_manager: ConnectionManager, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection_manager = connection_manager self._registry = registry + self._sanitizer = sanitizer self._paramstyle = connection_manager.adapter.paramstyle @classmethod @@ -79,46 +100,56 @@ def from_config( cls, config: Any, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> Engine: """Create an Engine from a ConnectionConfig and SQLRegistry. Args: config: ConnectionConfig instance registry: SQLRegistry instance + sanitizer: Optional SQLSanitizer applied to inline SQL strings. Returns: Engine instance """ connection_manager = ConnectionManager(config) - return cls(connection_manager, registry) + return cls(connection_manager, registry, sanitizer) def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: """Fetch a single row. + *query* may be a registry key (e.g. ``"users.get_by_id"``) or an + inline SQL string (e.g. ``"SELECT * FROM users WHERE id = ?"``). + + *params* may be a ``dict`` for named binding, a ``tuple``/``list`` for + positional binding, or a single scalar that is automatically wrapped in + a tuple. + Returns None if zero rows match. Raises MultipleRowsError if more than one row matches. """ - sql = self._registry.get(query_name) + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) if len(rows) == 0: return None if len(rows) > 1: - raise MultipleRowsError(query_name, len(rows)) + raise MultipleRowsError(label, len(rows)) row = rows[0] if mapper is not None: @@ -127,20 +158,25 @@ def fetch_one( def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch all matching rows.""" - sql = self._registry.get(query_name) + """Fetch all matching rows. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) @@ -150,18 +186,23 @@ def fetch_all( def fetch_scalar( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> Any: - """Fetch a single scalar value (first column of first row).""" - sql = self._registry.get(query_name) + """Fetch a single scalar value (first column of first row). + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e row = cursor.fetchone() if row is None: @@ -177,18 +218,23 @@ def fetch_scalar( def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query. Returns affected row count.""" - sql = self._registry.get(query_name) + """Execute a write query. Returns affected row count. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) with self._connection_manager.get_connection() as conn: try: - cursor = self._connection_manager.adapter.execute(conn, sql, params) + cursor = self._connection_manager.adapter.execute(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e conn.commit() return int(cursor.rowcount) @@ -204,6 +250,7 @@ def transaction(self) -> TransactionManager: adapter=self._connection_manager.adapter, registry=self._registry, pool=pool, + sanitizer=self._sanitizer, ) @@ -214,9 +261,11 @@ def __init__( self, connection_manager: AsyncConnectionManager, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection_manager = connection_manager self._registry = registry + self._sanitizer = sanitizer self._paramstyle = connection_manager.adapter.paramstyle @classmethod @@ -224,35 +273,42 @@ def from_config( cls, config: Any, registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> AsyncEngine: """Create an AsyncEngine from a ConnectionConfig and SQLRegistry. Args: config: ConnectionConfig instance registry: SQLRegistry instance + sanitizer: Optional SQLSanitizer applied to inline SQL strings. Returns: AsyncEngine instance """ connection_manager = AsyncConnectionManager(config) - return cls(connection_manager, registry) + return cls(connection_manager, registry, sanitizer) async def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch a single row asynchronously.""" - sql = self._registry.get(query_name) + """Fetch a single row asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return None @@ -268,7 +324,7 @@ async def fetch_one( if len(rows) == 0: return None if len(rows) > 1: - raise MultipleRowsError(query_name, len(rows)) + raise MultipleRowsError(label, len(rows)) row = rows[0] if mapper is not None: @@ -277,20 +333,25 @@ async def fetch_one( async def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, *, mapper: Any | None = None, ) -> Any: - """Fetch all matching rows asynchronously.""" - sql = self._registry.get(query_name) + """Fetch all matching rows asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return [] @@ -309,18 +370,23 @@ async def fetch_all( async def fetch_scalar( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> Any: - """Fetch a single scalar value asynchronously.""" - sql = self._registry.get(query_name) + """Fetch a single scalar value asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e row = await cursor.fetchone() if row is None: @@ -336,18 +402,23 @@ async def fetch_scalar( async def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query asynchronously.""" - sql = self._registry.get(query_name) + """Execute a write query asynchronously. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ + sql, label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) + bound = coerce_params(params) async with self._connection_manager.get_connection() as conn: try: - cursor = await self._connection_manager.adapter.execute_async(conn, sql, params) + cursor = await self._connection_manager.adapter.execute_async(conn, sql, bound) except Exception as e: - raise ParameterBindingError(query_name, str(e)) from e + raise ParameterBindingError(label, str(e)) from e await conn.commit() return int(cursor.rowcount) @@ -362,4 +433,5 @@ def transaction(self) -> AsyncTransactionManager: adapter=self._connection_manager.adapter, registry=self._registry, connection_manager=self._connection_manager, + sanitizer=self._sanitizer, ) diff --git a/row_query/core/exceptions.py b/row_query/core/exceptions.py index 0101332..313bc38 100644 --- a/row_query/core/exceptions.py +++ b/row_query/core/exceptions.py @@ -60,6 +60,13 @@ def __init__(self, query_name: str, detail: str) -> None: super().__init__(f"Parameter binding error for '{query_name}': {detail}") +class SQLSanitizationError(ExecutionError): + """Raised when an inline SQL string fails a sanitization check.""" + + def __init__(self, detail: str) -> None: + super().__init__(f"SQL sanitization failed: {detail}") + + # --- Mapping --- diff --git a/row_query/core/params.py b/row_query/core/params.py index 0cacdef..3725537 100644 --- a/row_query/core/params.py +++ b/row_query/core/params.py @@ -8,6 +8,7 @@ import re from functools import lru_cache +from typing import Any # Matches :name but not ::typecast and not inside words # Negative lookbehind for : (handles ::), \w (handles mid-word colons) @@ -53,3 +54,28 @@ def _convert_to_pyformat(sql: str) -> str: parts.append(_PARAM_PATTERN.sub(r"%(\1)s", sql[last_end:])) return "".join(parts) + + +def is_raw_sql(query: str) -> bool: + """Return True if query is an inline SQL string rather than a registry key. + + Registry keys use dot-notation (e.g. ``users.get_by_id``) and never + contain whitespace. Any SQL statement will contain at least one space. + """ + return any(c.isspace() for c in query) + + +def coerce_params( + params: dict[str, Any] | tuple[Any, ...] | list[Any] | Any, +) -> dict[str, Any] | tuple[Any, ...] | None: + """Normalize *params* to a dict, tuple, or None. + + * ``None`` / ``dict`` → returned as-is (named parameter binding). + * ``tuple`` / ``list`` → converted to ``tuple`` (positional binding). + * Any other scalar → wrapped in a single-element tuple. + """ + if params is None or isinstance(params, dict): + return params + if isinstance(params, (tuple, list)): + return tuple(params) + return (params,) diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py new file mode 100644 index 0000000..8ccb2d6 --- /dev/null +++ b/row_query/core/sanitizer.py @@ -0,0 +1,163 @@ +"""Inline SQL sanitizer. + +Only applied to raw SQL strings passed directly to engine/transaction methods. +Queries loaded from the SQLRegistry are trusted and never sanitized. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +from row_query.core.exceptions import SQLSanitizationError + +# Matches the first SQL keyword (used for verb allow-listing) +_FIRST_KEYWORD = re.compile(r"^\s*(\w+)") + + +# --------------------------------------------------------------------------- +# Internal tokenizer +# --------------------------------------------------------------------------- + + +def _tokenize(sql: str) -> list[tuple[str, str]]: + """Split *sql* into ``('string', …)`` and ``('code', …)`` tokens. + + String literals (single-quoted, with ``''`` escapes) are preserved as-is. + Everything else is a ``'code'`` token. + """ + tokens: list[tuple[str, str]] = [] + i = 0 + n = len(sql) + last = 0 + + while i < n: + if sql[i] == "'": + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == "'": + j += 1 + if j >= n or sql[j] != "'": + break # end of literal + j += 1 # '' escape — continue + else: + j += 1 + tokens.append(("string", sql[i:j])) + last = j + i = j + else: + i += 1 + + if last < n: + tokens.append(("code", sql[last:])) + + return tokens + + +def _strip_comments_in_code(code: str) -> str: + """Remove ``--`` line comments and ``/* */`` block comments from a code segment.""" + result: list[str] = [] + i = 0 + n = len(code) + + while i < n: + if code[i : i + 2] == "--": + j = code.find("\n", i) + if j == -1: + break + result.append("\n") + i = j + 1 + elif code[i : i + 2] == "/*": + j = code.find("*/", i + 2) + if j == -1: + break + result.append(" ") + i = j + 2 + else: + result.append(code[i]) + i += 1 + + return "".join(result) + + +# --------------------------------------------------------------------------- +# Individual sanitization checks +# --------------------------------------------------------------------------- + + +def _strip_comments(sql: str) -> str: + """Remove SQL comments while preserving string literals.""" + parts: list[str] = [] + for kind, content in _tokenize(sql): + if kind == "string": + parts.append(content) + else: + parts.append(_strip_comments_in_code(content)) + return "".join(parts) + + +def _check_single_statement(sql: str) -> None: + """Raise if *sql* contains a semicolon followed by non-whitespace content.""" + for kind, content in _tokenize(sql): + if kind == "string": + continue + for i, ch in enumerate(content): + if ch == ";" and content[i + 1 :].strip(): + raise SQLSanitizationError( + "Multiple SQL statements are not permitted in inline SQL" + ) + + +def _check_verb(sql: str, allowed: frozenset[str]) -> None: + """Raise if the leading SQL keyword is not in *allowed*.""" + m = _FIRST_KEYWORD.match(sql) + if m: + verb = m.group(1).upper() + if verb not in allowed: + raise SQLSanitizationError( + f"SQL verb '{verb}' is not permitted; " + f"allowed: {sorted(allowed)}" + ) + + +# --------------------------------------------------------------------------- +# Public class +# --------------------------------------------------------------------------- + + +@dataclass +class SQLSanitizer: + """Configurable sanitizer for inline SQL strings. + + Applied only to raw SQL passed directly to engine/transaction methods. + Registry-loaded queries are always trusted and never sanitized. + + Attributes: + strip_comments: Strip ``--`` and ``/* */`` comments before execution. + block_multiple_statements: Reject SQL that contains a statement- + terminating ``;`` followed by additional content (prevents query + stacking such as ``SELECT 1; DROP TABLE users``). + allowed_verbs: If not ``None``, only SQL statements whose first keyword + appears in this set are permitted. ``None`` means no restriction. + Example: ``frozenset({"SELECT", "INSERT", "UPDATE", "DELETE"})``. + """ + + strip_comments: bool = True + block_multiple_statements: bool = True + allowed_verbs: frozenset[str] | None = None + + def sanitize(self, sql: str) -> str: + """Apply all configured checks to *sql* and return the (cleaned) SQL. + + Raises: + SQLSanitizationError: If any enabled check fails. + """ + if self.strip_comments: + sql = _strip_comments(sql) + if self.block_multiple_statements: + _check_single_statement(sql) + if self.allowed_verbs is not None: + _check_verb(sql, self.allowed_verbs) + return sql diff --git a/row_query/core/transaction.py b/row_query/core/transaction.py index 94bd1d1..b0c2bf8 100644 --- a/row_query/core/transaction.py +++ b/row_query/core/transaction.py @@ -10,8 +10,21 @@ from typing import Any from row_query.core.exceptions import TransactionStateError -from row_query.core.params import normalize_params +from row_query.core.params import coerce_params, is_raw_sql, normalize_params from row_query.core.registry import SQLRegistry +from row_query.core.sanitizer import SQLSanitizer + + +def _resolve_sql( + query: str, + registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, +) -> tuple[str, str]: + """Return ``(sql_text, label)`` for *query* (raw SQL or registry key).""" + if is_raw_sql(query): + sql = sanitizer.sanitize(query) if sanitizer is not None else query + return sql, "" + return registry.get(query), query class _TxState(Enum): @@ -51,11 +64,13 @@ def __init__( adapter: Any, registry: SQLRegistry, pool: Any = None, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection = connection self._adapter = adapter self._registry = registry self._pool = pool + self._sanitizer = sanitizer self._paramstyle: str = adapter.paramstyle self._state = _TxState.IDLE @@ -83,26 +98,34 @@ def __exit__( def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query within this transaction.""" + """Execute a write query within this transaction. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) return int(cursor.rowcount) def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> dict[str, Any] | None: - """Fetch a single row within transaction context.""" + """Fetch a single row within transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) rows = _rows_to_dicts(cursor) if not rows: return None @@ -110,14 +133,18 @@ def fetch_one( def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> list[dict[str, Any]]: - """Fetch all rows within transaction context.""" + """Fetch all rows within transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, params) + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) return _rows_to_dicts(cursor) def commit(self) -> None: @@ -155,11 +182,13 @@ def __init__( connection_manager: Any, connection: Any = None, pool: Any = None, + sanitizer: SQLSanitizer | None = None, ) -> None: self._connection = connection self._adapter = adapter self._registry = registry self._pool = pool + self._sanitizer = sanitizer self._connection_manager = connection_manager self._paramstyle: str = adapter.paramstyle self._state = _TxState.IDLE @@ -194,26 +223,34 @@ async def __aexit__( async def execute( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> int: - """Execute a write query within this async transaction.""" + """Execute a write query within this async transaction. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) return int(cursor.rowcount) async def fetch_one( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> dict[str, Any] | None: - """Fetch a single row within async transaction context.""" + """Fetch a single row within async transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) if cursor.description is None: return None columns = [desc[0] for desc in cursor.description] @@ -231,14 +268,18 @@ async def fetch_one( async def fetch_all( self, - query_name: str, - params: dict[str, Any] | None = None, + query: str, + params: Any = None, ) -> list[dict[str, Any]]: - """Fetch all rows within async transaction context.""" + """Fetch all rows within async transaction context. + + *query* may be a registry key or an inline SQL string. + *params* may be a dict, tuple/list, or scalar. + """ self._check_active() - sql = self._registry.get(query_name) + sql, _label = _resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, params) + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) if cursor.description is None: return [] columns = [desc[0] for desc in cursor.description] diff --git a/tests/unit/test_sanitizer.py b/tests/unit/test_sanitizer.py new file mode 100644 index 0000000..3782839 --- /dev/null +++ b/tests/unit/test_sanitizer.py @@ -0,0 +1,346 @@ +"""Unit tests for SQLSanitizer and related helpers.""" + +from __future__ import annotations + +import pytest + +from row_query.core.exceptions import SQLSanitizationError +from row_query.core.params import coerce_params, is_raw_sql +from row_query.core.sanitizer import ( + SQLSanitizer, + _check_single_statement, + _check_verb, + _strip_comments, +) + + +# --------------------------------------------------------------------------- +# is_raw_sql +# --------------------------------------------------------------------------- + + +class TestIsRawSql: + def test_sql_with_space_is_raw(self) -> None: + assert is_raw_sql("SELECT 1") is True + + def test_sql_with_newline_is_raw(self) -> None: + assert is_raw_sql("SELECT\n1") is True + + def test_sql_with_tab_is_raw(self) -> None: + assert is_raw_sql("SELECT\t1") is True + + def test_registry_key_is_not_raw(self) -> None: + assert is_raw_sql("users.get_by_id") is False + + def test_dotted_registry_key_is_not_raw(self) -> None: + assert is_raw_sql("billing.invoice.list") is False + + def test_empty_string_is_not_raw(self) -> None: + assert is_raw_sql("") is False + + +# --------------------------------------------------------------------------- +# coerce_params +# --------------------------------------------------------------------------- + + +class TestCoerceParams: + def test_none_passthrough(self) -> None: + assert coerce_params(None) is None + + def test_dict_passthrough(self) -> None: + p = {"id": 1} + assert coerce_params(p) is p + + def test_tuple_passthrough(self) -> None: + p = (1, 2) + assert coerce_params(p) == (1, 2) + + def test_list_converted_to_tuple(self) -> None: + assert coerce_params([1, 2]) == (1, 2) + + def test_int_scalar_wrapped(self) -> None: + assert coerce_params(42) == (42,) + + def test_str_scalar_wrapped(self) -> None: + assert coerce_params("hello") == ("hello",) + + def test_bool_scalar_wrapped(self) -> None: + assert coerce_params(True) == (True,) + + def test_empty_tuple_passthrough(self) -> None: + assert coerce_params(()) == () + + def test_empty_list_converted(self) -> None: + assert coerce_params([]) == () + + +# --------------------------------------------------------------------------- +# _strip_comments +# --------------------------------------------------------------------------- + + +class TestStripComments: + def test_no_comments_passthrough(self) -> None: + sql = "SELECT id FROM users WHERE id = 1" + assert _strip_comments(sql) == sql + + def test_line_comment_removed(self) -> None: + sql = "SELECT 1 -- get one" + assert _strip_comments(sql) == "SELECT 1 " + + def test_line_comment_preserves_newline(self) -> None: + sql = "SELECT 1 -- comment\nFROM t" + assert _strip_comments(sql) == "SELECT 1 \nFROM t" + + def test_line_comment_at_start(self) -> None: + sql = "-- full line comment\nSELECT 1" + assert _strip_comments(sql) == "\nSELECT 1" + + def test_multiple_line_comments(self) -> None: + sql = "SELECT 1 -- first\nFROM t -- second\nWHERE 1=1" + assert _strip_comments(sql) == "SELECT 1 \nFROM t \nWHERE 1=1" + + def test_block_comment_removed(self) -> None: + sql = "SELECT /* inline */ 1" + assert _strip_comments(sql) == "SELECT 1" + + def test_block_comment_replaced_with_space(self) -> None: + # Block comments become a single space to avoid token merging + sql = "SELECT/*comment*/1" + assert _strip_comments(sql) == "SELECT 1" + + def test_multiline_block_comment_removed(self) -> None: + sql = "SELECT /*\n big\n comment\n*/ 1" + assert _strip_comments(sql) == "SELECT 1" + + def test_string_literal_with_double_dash_preserved(self) -> None: + sql = "SELECT '-- not a comment' FROM t" + assert _strip_comments(sql) == sql + + def test_string_literal_with_block_comment_syntax_preserved(self) -> None: + sql = "SELECT '/* not a comment */' FROM t" + assert _strip_comments(sql) == sql + + def test_string_literal_adjacent_to_comment(self) -> None: + sql = "SELECT 'value' -- comment\nFROM t" + assert _strip_comments(sql) == "SELECT 'value' \nFROM t" + + def test_escaped_quote_in_string_literal(self) -> None: + sql = "SELECT 'it''s fine -- not a comment' FROM t" + assert _strip_comments(sql) == sql + + def test_unclosed_block_comment_strips_remainder(self) -> None: + sql = "SELECT 1 /* unclosed" + result = _strip_comments(sql) + assert "/*" not in result + assert result.startswith("SELECT 1 ") + + def test_line_comment_without_trailing_newline_strips_to_end(self) -> None: + sql = "SELECT 1 -- trailing only" + result = _strip_comments(sql) + assert "--" not in result + + +# --------------------------------------------------------------------------- +# _check_single_statement +# --------------------------------------------------------------------------- + + +class TestCheckSingleStatement: + def test_no_semicolon_passes(self) -> None: + _check_single_statement("SELECT 1") # no raise + + def test_trailing_semicolon_passes(self) -> None: + _check_single_statement("SELECT 1;") + + def test_trailing_semicolon_with_whitespace_passes(self) -> None: + _check_single_statement("SELECT 1; ") + + def test_trailing_semicolon_with_newline_passes(self) -> None: + _check_single_statement("SELECT 1;\n") + + def test_multiple_statements_raises(self) -> None: + with pytest.raises(SQLSanitizationError, match="Multiple SQL statements"): + _check_single_statement("SELECT 1; DROP TABLE users") + + def test_two_selects_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT 1; SELECT 2") + + def test_no_space_between_statements_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT 1;SELECT 2") + + def test_semicolon_in_string_literal_passes(self) -> None: + _check_single_statement("SELECT ';' FROM t") + + def test_multiple_semicolons_in_string_passes(self) -> None: + _check_single_statement("SELECT 'a;b;c' FROM t") + + def test_semicolon_after_string_then_statement_raises(self) -> None: + with pytest.raises(SQLSanitizationError): + _check_single_statement("SELECT ';' FROM t; DROP TABLE t") + + +# --------------------------------------------------------------------------- +# _check_verb +# --------------------------------------------------------------------------- + + +class TestCheckVerb: + ALLOWED = frozenset({"SELECT", "INSERT", "UPDATE", "DELETE"}) + + def test_allowed_verb_passes(self) -> None: + _check_verb("SELECT * FROM t", self.ALLOWED) + + def test_disallowed_verb_raises(self) -> None: + with pytest.raises(SQLSanitizationError, match="DROP"): + _check_verb("DROP TABLE users", self.ALLOWED) + + def test_lowercase_verb_matches_case_insensitively(self) -> None: + _check_verb("select * FROM t", self.ALLOWED) + + def test_mixed_case_verb_matches(self) -> None: + _check_verb("Select * FROM t", self.ALLOWED) + + def test_leading_whitespace_ignored(self) -> None: + _check_verb(" \n SELECT * FROM t", self.ALLOWED) + + def test_truncate_blocked(self) -> None: + with pytest.raises(SQLSanitizationError, match="TRUNCATE"): + _check_verb("TRUNCATE TABLE users", self.ALLOWED) + + def test_alter_blocked(self) -> None: + with pytest.raises(SQLSanitizationError, match="ALTER"): + _check_verb("ALTER TABLE users ADD COLUMN foo TEXT", self.ALLOWED) + + def test_with_cte_can_be_allowed(self) -> None: + allowed = self.ALLOWED | frozenset({"WITH"}) + _check_verb("WITH cte AS (SELECT 1) SELECT * FROM cte", allowed) + + def test_error_message_lists_allowed_verbs(self) -> None: + with pytest.raises(SQLSanitizationError, match="allowed"): + _check_verb("DROP TABLE t", frozenset({"SELECT"})) + + +# --------------------------------------------------------------------------- +# SQLSanitizer — configuration and defaults +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerDefaults: + def test_default_strips_comments(self) -> None: + s = SQLSanitizer() + result = s.sanitize("SELECT 1 -- comment") + assert "--" not in result + + def test_default_blocks_multiple_statements(self) -> None: + s = SQLSanitizer() + with pytest.raises(SQLSanitizationError): + s.sanitize("SELECT 1; DROP TABLE t") + + def test_default_allows_any_verb(self) -> None: + s = SQLSanitizer() + s.sanitize("DROP TABLE t") # no raise — no verb restriction by default + + def test_returns_cleaned_sql_string(self) -> None: + s = SQLSanitizer() + result = s.sanitize("SELECT /* inline */ 1") + assert isinstance(result, str) + assert "/*" not in result + + +class TestSQLSanitizerFlags: + def test_strip_comments_false_preserves_comments(self) -> None: + s = SQLSanitizer(strip_comments=False) + sql = "SELECT 1 -- comment" + assert s.sanitize(sql) == sql + + def test_block_multiple_statements_false_allows_multi(self) -> None: + s = SQLSanitizer(block_multiple_statements=False) + # Should not raise + s.sanitize("SELECT 1; SELECT 2") + + def test_allowed_verbs_none_permits_any(self) -> None: + s = SQLSanitizer(allowed_verbs=None) + s.sanitize("DROP TABLE t") # no raise + + def test_allowed_verbs_set_blocks_others(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError): + s.sanitize("DROP TABLE t") + + def test_allowed_verbs_set_permits_listed(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT", "INSERT"})) + s.sanitize("INSERT INTO t VALUES (1)") # no raise + + def test_all_checks_disabled(self) -> None: + s = SQLSanitizer( + strip_comments=False, + block_multiple_statements=False, + allowed_verbs=None, + ) + # Nothing should raise, nothing should be modified + dangerous = "DROP TABLE t; DELETE FROM users -- yolo" + assert s.sanitize(dangerous) == dangerous + + +# --------------------------------------------------------------------------- +# SQLSanitizer — ordering: comments stripped before statement check +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerOrdering: + def test_comment_stripped_before_statement_check(self) -> None: + # After stripping, the trailing semicolon has only whitespace after it + s = SQLSanitizer() + sql = "SELECT 1; -- this is a comment, not a second statement" + # Should NOT raise: after stripping the comment, only "SELECT 1; " remains + result = s.sanitize(sql) + assert "SELECT 1" in result + + def test_comment_stripped_before_verb_check(self) -> None: + # A comment-only line before the real verb should not confuse the checker + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + sql = "-- preamble\nSELECT * FROM t" + s.sanitize(sql) # no raise + + def test_multiline_with_trailing_comment_passes(self) -> None: + s = SQLSanitizer() + sql = "SELECT *\nFROM users\nWHERE id = 1 -- primary key" + result = s.sanitize(sql) + assert "FROM users" in result + assert "--" not in result + + +# --------------------------------------------------------------------------- +# SQLSanitizer — error message quality +# --------------------------------------------------------------------------- + + +class TestSQLSanitizerErrorMessages: + def test_multiple_statements_error_message(self) -> None: + s = SQLSanitizer() + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("SELECT 1; DROP TABLE t") + assert "Multiple SQL statements" in str(exc_info.value) + + def test_verb_error_includes_offending_verb(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("TRUNCATE TABLE users") + assert "TRUNCATE" in str(exc_info.value) + + def test_verb_error_includes_allowed_list(self) -> None: + s = SQLSanitizer(allowed_verbs=frozenset({"SELECT"})) + with pytest.raises(SQLSanitizationError) as exc_info: + s.sanitize("DROP TABLE t") + assert "SELECT" in str(exc_info.value) + + def test_exception_is_subclass_of_execution_error(self) -> None: + from row_query.core.exceptions import ExecutionError + + s = SQLSanitizer() + with pytest.raises(ExecutionError): + s.sanitize("SELECT 1; DROP TABLE t") From a3ffdbf51464b7abaf6b955d60bc005ac79d2b66 Mon Sep 17 00:00:00 2001 From: xmaksutx Date: Tue, 17 Feb 2026 22:37:08 +0100 Subject: [PATCH 2/7] Fix ruff formatting in sanitizer.py Co-Authored-By: Claude Sonnet 4.6 --- row_query/core/sanitizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py index 8ccb2d6..6517736 100644 --- a/row_query/core/sanitizer.py +++ b/row_query/core/sanitizer.py @@ -117,8 +117,7 @@ def _check_verb(sql: str, allowed: frozenset[str]) -> None: verb = m.group(1).upper() if verb not in allowed: raise SQLSanitizationError( - f"SQL verb '{verb}' is not permitted; " - f"allowed: {sorted(allowed)}" + f"SQL verb '{verb}' is not permitted; allowed: {sorted(allowed)}" ) From cadd1361b272242e46a11aedbc90aaec34cf392a Mon Sep 17 00:00:00 2001 From: xmaksutx Date: Tue, 17 Feb 2026 22:37:43 +0100 Subject: [PATCH 3/7] Fix ruff import ordering in test_sanitizer.py Co-Authored-By: Claude Sonnet 4.6 --- tests/unit/test_sanitizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_sanitizer.py b/tests/unit/test_sanitizer.py index 3782839..aa2087d 100644 --- a/tests/unit/test_sanitizer.py +++ b/tests/unit/test_sanitizer.py @@ -13,7 +13,6 @@ _strip_comments, ) - # --------------------------------------------------------------------------- # is_raw_sql # --------------------------------------------------------------------------- From 94a7b23cfa1ec0506683e3390e0dc1ac4dfd6735 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 21:59:22 +0000 Subject: [PATCH 4/7] Initial plan From 8502291ca4c1a31f76b91202ae88a01dd82b829b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 22:04:06 +0000 Subject: [PATCH 5/7] Move _resolve_sql to shared location and add error handling to transaction methods Co-authored-by: MaksimShevtsov <20194438+MaksimShevtsov@users.noreply.github.com> --- row_query/core/engine.py | 36 ++++------------ row_query/core/params.py | 43 ++++++++++++++++++- row_query/core/registry.py | 10 +++++ row_query/core/sanitizer.py | 80 +++++++++++++++++++++++++++++++++-- row_query/core/transaction.py | 58 +++++++++++++------------ 5 files changed, 169 insertions(+), 58 deletions(-) diff --git a/row_query/core/engine.py b/row_query/core/engine.py index 98f80f5..0cdb372 100644 --- a/row_query/core/engine.py +++ b/row_query/core/engine.py @@ -13,7 +13,7 @@ MultipleRowsError, ParameterBindingError, ) -from row_query.core.params import coerce_params, is_raw_sql, normalize_params +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager @@ -21,24 +21,6 @@ T = TypeVar("T") -def _resolve_sql( - query: str, - registry: SQLRegistry, - sanitizer: SQLSanitizer | None = None, -) -> tuple[str, str]: - """Return ``(sql_text, label)`` for *query*. - - If *query* is an inline SQL string (contains whitespace) it is returned - after optional sanitization. Otherwise it is looked up in *registry* by - name (registry queries are trusted and never sanitized). *label* is used - in error messages. - """ - if is_raw_sql(query): - sql = sanitizer.sanitize(query) if sanitizer is not None else query - return sql, "" - return registry.get(query), query - - def _rows_to_dicts(cursor: Any) -> list[dict[str, Any]]: """Convert cursor results to list of dicts. @@ -134,7 +116,7 @@ def fetch_one( Returns None if zero rows match. Raises MultipleRowsError if more than one row matches. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -168,7 +150,7 @@ def fetch_all( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -194,7 +176,7 @@ def fetch_scalar( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -226,7 +208,7 @@ def execute( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -300,7 +282,7 @@ async def fetch_one( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -343,7 +325,7 @@ async def fetch_all( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -378,7 +360,7 @@ async def fetch_scalar( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -410,7 +392,7 @@ async def execute( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) diff --git a/row_query/core/params.py b/row_query/core/params.py index 3725537..7fe0d22 100644 --- a/row_query/core/params.py +++ b/row_query/core/params.py @@ -8,7 +8,11 @@ import re from functools import lru_cache -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from row_query.core.registry import SQLRegistry + from row_query.core.sanitizer import SQLSanitizer # Matches :name but not ::typecast and not inside words # Negative lookbehind for : (handles ::), \w (handles mid-word colons) @@ -61,6 +65,9 @@ def is_raw_sql(query: str) -> bool: Registry keys use dot-notation (e.g. ``users.get_by_id``) and never contain whitespace. Any SQL statement will contain at least one space. + + Note: Registry keys are validated during registration to ensure they do + not contain whitespace, preventing ambiguity. """ return any(c.isspace() for c in query) @@ -73,9 +80,43 @@ def coerce_params( * ``None`` / ``dict`` → returned as-is (named parameter binding). * ``tuple`` / ``list`` → converted to ``tuple`` (positional binding). * Any other scalar → wrapped in a single-element tuple. + + Note on parameter styles: + Registry queries use `:name` style parameters (converted to driver format). + Inline SQL can use either `:name` or `?`-style placeholders depending on + the database driver. When using inline SQL with positional parameters, + ensure compatibility with your target database (SQLite uses `?`, PostgreSQL + uses `$1`, etc.). """ if params is None or isinstance(params, dict): return params if isinstance(params, (tuple, list)): return tuple(params) return (params,) + + +def resolve_sql( + query: str, + registry: "SQLRegistry", + sanitizer: "SQLSanitizer | None" = None, +) -> tuple[str, str]: + """Return ``(sql_text, label)`` for *query*. + + If *query* is an inline SQL string (contains whitespace) it is returned + after optional sanitization. Otherwise it is looked up in *registry* by + name (registry queries are trusted and never sanitized). *label* is used + in error messages. + + Args: + query: Either a registry key (e.g. "users.get_by_id") or inline SQL. + registry: SQLRegistry instance for looking up named queries. + sanitizer: Optional SQLSanitizer applied only to inline SQL strings. + + Returns: + Tuple of (sql_text, label) where label is "" for inline SQL + or the registry key for named queries. + """ + if is_raw_sql(query): + sql = sanitizer.sanitize(query) if sanitizer is not None else query + return sql, "" + return registry.get(query), query diff --git a/row_query/core/registry.py b/row_query/core/registry.py index aa031f4..8133494 100644 --- a/row_query/core/registry.py +++ b/row_query/core/registry.py @@ -43,6 +43,16 @@ def _load(self) -> None: parts[-1] = parts[-1].removesuffix(".sql") query_name = ".".join(parts) + # Validate that query_name doesn't contain whitespace + # This prevents ambiguity with inline SQL detection + if any(c.isspace() for c in query_name): + from row_query.core.exceptions import ExecutionError + raise ExecutionError( + f"Registry key '{query_name}' from file '{sql_file}' contains " + f"whitespace, which is not allowed. Registry keys must not contain " + f"spaces, tabs, or newlines to avoid ambiguity with inline SQL." + ) + if query_name in self._queries: raise DuplicateQueryError( query_name, diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py index 6517736..93f5b6c 100644 --- a/row_query/core/sanitizer.py +++ b/row_query/core/sanitizer.py @@ -21,10 +21,15 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: - """Split *sql* into ``('string', …)`` and ``('code', …)`` tokens. + """Split *sql* into ``('string', …)``, ``('identifier', …)``, and ``('code', …)`` tokens. String literals (single-quoted, with ``''`` escapes) are preserved as-is. + Identifiers (double-quoted for PostgreSQL/MySQL ANSI_QUOTES, backtick-quoted + for MySQL) are also preserved to avoid stripping comment-like syntax inside them. Everything else is a ``'code'`` token. + + Raises: + SQLSanitizationError: If an unterminated string literal or identifier is detected. """ tokens: list[tuple[str, str]] = [] i = 0 @@ -32,6 +37,7 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: last = 0 while i < n: + # Single-quoted string literal if sql[i] == "'": if i > last: tokens.append(("code", sql[last:i])) @@ -44,9 +50,59 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: j += 1 # '' escape — continue else: j += 1 + # Check for unterminated string + if j >= n and (j == i + 1 or sql[j - 1] != "'"): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated string literal detected in SQL" + ) tokens.append(("string", sql[i:j])) last = j i = j + # Double-quoted identifier (PostgreSQL, MySQL ANSI_QUOTES) + elif sql[i] == '"': + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == '"': + j += 1 + if j >= n or sql[j] != '"': + break # end of identifier + j += 1 # "" escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != '"'): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated double-quoted identifier detected in SQL" + ) + tokens.append(("identifier", sql[i:j])) + last = j + i = j + # Backtick-quoted identifier (MySQL) + elif sql[i] == "`": + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == "`": + j += 1 + if j >= n or sql[j] != "`": + break # end of identifier + j += 1 # `` escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != "`"): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated backtick-quoted identifier detected in SQL" + ) + tokens.append(("identifier", sql[i:j])) + last = j + i = j else: i += 1 @@ -88,10 +144,10 @@ def _strip_comments_in_code(code: str) -> str: def _strip_comments(sql: str) -> str: - """Remove SQL comments while preserving string literals.""" + """Remove SQL comments while preserving string literals and identifiers.""" parts: list[str] = [] for kind, content in _tokenize(sql): - if kind == "string": + if kind in ("string", "identifier"): parts.append(content) else: parts.append(_strip_comments_in_code(content)) @@ -101,7 +157,7 @@ def _strip_comments(sql: str) -> str: def _check_single_statement(sql: str) -> None: """Raise if *sql* contains a semicolon followed by non-whitespace content.""" for kind, content in _tokenize(sql): - if kind == "string": + if kind in ("string", "identifier"): continue for i, ch in enumerate(content): if ch == ";" and content[i + 1 :].strip(): @@ -133,6 +189,22 @@ class SQLSanitizer: Applied only to raw SQL passed directly to engine/transaction methods. Registry-loaded queries are always trusted and never sanitized. + **IMPORTANT SECURITY WARNING:** + This sanitizer does NOT protect against SQL injection if user-provided + data is concatenated directly into SQL strings. You MUST use parameterized + queries with placeholders (e.g., `?` or `:name`) to prevent SQL injection. + The sanitizer only provides defense-in-depth measures (comment stripping, + statement blocking, verb restrictions) but is NOT a substitute for proper + parameterization. + + Example of UNSAFE code: + # NEVER DO THIS - vulnerable to SQL injection + engine.fetch_all(f"SELECT * FROM users WHERE name = '{user_input}'") + + Example of SAFE code: + # ALWAYS USE THIS - parameterized query + engine.fetch_all("SELECT * FROM users WHERE name = ?", user_input) + Attributes: strip_comments: Strip ``--`` and ``/* */`` comments before execution. block_multiple_statements: Reject SQL that contains a statement- diff --git a/row_query/core/transaction.py b/row_query/core/transaction.py index b0c2bf8..62a0c4b 100644 --- a/row_query/core/transaction.py +++ b/row_query/core/transaction.py @@ -9,24 +9,12 @@ from enum import Enum from typing import Any -from row_query.core.exceptions import TransactionStateError -from row_query.core.params import coerce_params, is_raw_sql, normalize_params +from row_query.core.exceptions import ParameterBindingError, TransactionStateError +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry from row_query.core.sanitizer import SQLSanitizer -def _resolve_sql( - query: str, - registry: SQLRegistry, - sanitizer: SQLSanitizer | None = None, -) -> tuple[str, str]: - """Return ``(sql_text, label)`` for *query* (raw SQL or registry key).""" - if is_raw_sql(query): - sql = sanitizer.sanitize(query) if sanitizer is not None else query - return sql, "" - return registry.get(query), query - - class _TxState(Enum): IDLE = "idle" ACTIVE = "active" @@ -107,9 +95,12 @@ def execute( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) def fetch_one( @@ -123,9 +114,12 @@ def fetch_one( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) if not rows: return None @@ -142,9 +136,12 @@ def fetch_all( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return _rows_to_dicts(cursor) def commit(self) -> None: @@ -232,9 +229,12 @@ async def execute( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) async def fetch_one( @@ -248,9 +248,12 @@ async def fetch_one( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return None columns = [desc[0] for desc in cursor.description] @@ -277,9 +280,12 @@ async def fetch_all( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return [] columns = [desc[0] for desc in cursor.description] From 9b3b86ee4cc3e152899ebdb8c8a8a3479c4f22db Mon Sep 17 00:00:00 2001 From: xmaksutx Date: Tue, 17 Feb 2026 23:10:55 +0100 Subject: [PATCH 6/7] Refactor async transaction methods for improved readability by formatting SQL execution calls --- row_query/core/params.py | 4 ++-- row_query/core/transaction.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/row_query/core/params.py b/row_query/core/params.py index 7fe0d22..a78c0f0 100644 --- a/row_query/core/params.py +++ b/row_query/core/params.py @@ -97,8 +97,8 @@ def coerce_params( def resolve_sql( query: str, - registry: "SQLRegistry", - sanitizer: "SQLSanitizer | None" = None, + registry: SQLRegistry, + sanitizer: SQLSanitizer | None = None, ) -> tuple[str, str]: """Return ``(sql_text, label)`` for *query*. diff --git a/row_query/core/transaction.py b/row_query/core/transaction.py index 62a0c4b..8ebd140 100644 --- a/row_query/core/transaction.py +++ b/row_query/core/transaction.py @@ -232,7 +232,11 @@ async def execute( sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) try: - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) except Exception as e: raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) @@ -251,7 +255,11 @@ async def fetch_one( sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) try: - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) except Exception as e: raise ParameterBindingError(label, str(e)) from e if cursor.description is None: @@ -283,7 +291,11 @@ async def fetch_all( sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) try: - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + cursor = await self._adapter.execute_async( + self._connection, + sql, + coerce_params(params), + ) except Exception as e: raise ParameterBindingError(label, str(e)) from e if cursor.description is None: From 55c8387741d4d18fa4080228732ca66fc24e1fa8 Mon Sep 17 00:00:00 2001 From: xmaksutx Date: Tue, 17 Feb 2026 23:12:11 +0100 Subject: [PATCH 7/7] Improve error handling in SQL sanitization and registry key validation --- row_query/core/registry.py | 1 + row_query/core/sanitizer.py | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/row_query/core/registry.py b/row_query/core/registry.py index 8133494..9e18f0c 100644 --- a/row_query/core/registry.py +++ b/row_query/core/registry.py @@ -47,6 +47,7 @@ def _load(self) -> None: # This prevents ambiguity with inline SQL detection if any(c.isspace() for c in query_name): from row_query.core.exceptions import ExecutionError + raise ExecutionError( f"Registry key '{query_name}' from file '{sql_file}' contains " f"whitespace, which is not allowed. Registry keys must not contain " diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py index 93f5b6c..919e87d 100644 --- a/row_query/core/sanitizer.py +++ b/row_query/core/sanitizer.py @@ -53,9 +53,8 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: # Check for unterminated string if j >= n and (j == i + 1 or sql[j - 1] != "'"): from row_query.core.exceptions import SQLSanitizationError - raise SQLSanitizationError( - "Unterminated string literal detected in SQL" - ) + + raise SQLSanitizationError("Unterminated string literal detected in SQL") tokens.append(("string", sql[i:j])) last = j i = j @@ -75,9 +74,8 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: # Check for unterminated identifier if j >= n and (j == i + 1 or sql[j - 1] != '"'): from row_query.core.exceptions import SQLSanitizationError - raise SQLSanitizationError( - "Unterminated double-quoted identifier detected in SQL" - ) + + raise SQLSanitizationError("Unterminated double-quoted identifier detected in SQL") tokens.append(("identifier", sql[i:j])) last = j i = j @@ -97,6 +95,7 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: # Check for unterminated identifier if j >= n and (j == i + 1 or sql[j - 1] != "`"): from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( "Unterminated backtick-quoted identifier detected in SQL" )