From 94a7b23cfa1ec0506683e3390e0dc1ac4dfd6735 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 21:59:22 +0000 Subject: [PATCH 1/2] Initial plan From 8502291ca4c1a31f76b91202ae88a01dd82b829b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 22:04:06 +0000 Subject: [PATCH 2/2] Move _resolve_sql to shared location and add error handling to transaction methods Co-authored-by: MaksimShevtsov <20194438+MaksimShevtsov@users.noreply.github.com> --- row_query/core/engine.py | 36 ++++------------ row_query/core/params.py | 43 ++++++++++++++++++- row_query/core/registry.py | 10 +++++ row_query/core/sanitizer.py | 80 +++++++++++++++++++++++++++++++++-- row_query/core/transaction.py | 58 +++++++++++++------------ 5 files changed, 169 insertions(+), 58 deletions(-) diff --git a/row_query/core/engine.py b/row_query/core/engine.py index 98f80f5..0cdb372 100644 --- a/row_query/core/engine.py +++ b/row_query/core/engine.py @@ -13,7 +13,7 @@ MultipleRowsError, ParameterBindingError, ) -from row_query.core.params import coerce_params, is_raw_sql, normalize_params +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry from row_query.core.sanitizer import SQLSanitizer from row_query.core.transaction import AsyncTransactionManager, TransactionManager @@ -21,24 +21,6 @@ T = TypeVar("T") -def _resolve_sql( - query: str, - registry: SQLRegistry, - sanitizer: SQLSanitizer | None = None, -) -> tuple[str, str]: - """Return ``(sql_text, label)`` for *query*. - - If *query* is an inline SQL string (contains whitespace) it is returned - after optional sanitization. Otherwise it is looked up in *registry* by - name (registry queries are trusted and never sanitized). *label* is used - in error messages. - """ - if is_raw_sql(query): - sql = sanitizer.sanitize(query) if sanitizer is not None else query - return sql, "" - return registry.get(query), query - - def _rows_to_dicts(cursor: Any) -> list[dict[str, Any]]: """Convert cursor results to list of dicts. @@ -134,7 +116,7 @@ def fetch_one( Returns None if zero rows match. Raises MultipleRowsError if more than one row matches. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -168,7 +150,7 @@ def fetch_all( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -194,7 +176,7 @@ def fetch_scalar( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -226,7 +208,7 @@ def execute( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -300,7 +282,7 @@ async def fetch_one( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -343,7 +325,7 @@ async def fetch_all( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -378,7 +360,7 @@ async def fetch_scalar( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) @@ -410,7 +392,7 @@ async def execute( *query* may be a registry key or an inline SQL string. *params* may be a dict, tuple/list, or scalar. """ - sql, label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) bound = coerce_params(params) diff --git a/row_query/core/params.py b/row_query/core/params.py index 3725537..7fe0d22 100644 --- a/row_query/core/params.py +++ b/row_query/core/params.py @@ -8,7 +8,11 @@ import re from functools import lru_cache -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from row_query.core.registry import SQLRegistry + from row_query.core.sanitizer import SQLSanitizer # Matches :name but not ::typecast and not inside words # Negative lookbehind for : (handles ::), \w (handles mid-word colons) @@ -61,6 +65,9 @@ def is_raw_sql(query: str) -> bool: Registry keys use dot-notation (e.g. ``users.get_by_id``) and never contain whitespace. Any SQL statement will contain at least one space. + + Note: Registry keys are validated during registration to ensure they do + not contain whitespace, preventing ambiguity. """ return any(c.isspace() for c in query) @@ -73,9 +80,43 @@ def coerce_params( * ``None`` / ``dict`` → returned as-is (named parameter binding). * ``tuple`` / ``list`` → converted to ``tuple`` (positional binding). * Any other scalar → wrapped in a single-element tuple. + + Note on parameter styles: + Registry queries use `:name` style parameters (converted to driver format). + Inline SQL can use either `:name` or `?`-style placeholders depending on + the database driver. When using inline SQL with positional parameters, + ensure compatibility with your target database (SQLite uses `?`, PostgreSQL + uses `$1`, etc.). """ if params is None or isinstance(params, dict): return params if isinstance(params, (tuple, list)): return tuple(params) return (params,) + + +def resolve_sql( + query: str, + registry: "SQLRegistry", + sanitizer: "SQLSanitizer | None" = None, +) -> tuple[str, str]: + """Return ``(sql_text, label)`` for *query*. + + If *query* is an inline SQL string (contains whitespace) it is returned + after optional sanitization. Otherwise it is looked up in *registry* by + name (registry queries are trusted and never sanitized). *label* is used + in error messages. + + Args: + query: Either a registry key (e.g. "users.get_by_id") or inline SQL. + registry: SQLRegistry instance for looking up named queries. + sanitizer: Optional SQLSanitizer applied only to inline SQL strings. + + Returns: + Tuple of (sql_text, label) where label is "" for inline SQL + or the registry key for named queries. + """ + if is_raw_sql(query): + sql = sanitizer.sanitize(query) if sanitizer is not None else query + return sql, "" + return registry.get(query), query diff --git a/row_query/core/registry.py b/row_query/core/registry.py index aa031f4..8133494 100644 --- a/row_query/core/registry.py +++ b/row_query/core/registry.py @@ -43,6 +43,16 @@ def _load(self) -> None: parts[-1] = parts[-1].removesuffix(".sql") query_name = ".".join(parts) + # Validate that query_name doesn't contain whitespace + # This prevents ambiguity with inline SQL detection + if any(c.isspace() for c in query_name): + from row_query.core.exceptions import ExecutionError + raise ExecutionError( + f"Registry key '{query_name}' from file '{sql_file}' contains " + f"whitespace, which is not allowed. Registry keys must not contain " + f"spaces, tabs, or newlines to avoid ambiguity with inline SQL." + ) + if query_name in self._queries: raise DuplicateQueryError( query_name, diff --git a/row_query/core/sanitizer.py b/row_query/core/sanitizer.py index 6517736..93f5b6c 100644 --- a/row_query/core/sanitizer.py +++ b/row_query/core/sanitizer.py @@ -21,10 +21,15 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: - """Split *sql* into ``('string', …)`` and ``('code', …)`` tokens. + """Split *sql* into ``('string', …)``, ``('identifier', …)``, and ``('code', …)`` tokens. String literals (single-quoted, with ``''`` escapes) are preserved as-is. + Identifiers (double-quoted for PostgreSQL/MySQL ANSI_QUOTES, backtick-quoted + for MySQL) are also preserved to avoid stripping comment-like syntax inside them. Everything else is a ``'code'`` token. + + Raises: + SQLSanitizationError: If an unterminated string literal or identifier is detected. """ tokens: list[tuple[str, str]] = [] i = 0 @@ -32,6 +37,7 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: last = 0 while i < n: + # Single-quoted string literal if sql[i] == "'": if i > last: tokens.append(("code", sql[last:i])) @@ -44,9 +50,59 @@ def _tokenize(sql: str) -> list[tuple[str, str]]: j += 1 # '' escape — continue else: j += 1 + # Check for unterminated string + if j >= n and (j == i + 1 or sql[j - 1] != "'"): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated string literal detected in SQL" + ) tokens.append(("string", sql[i:j])) last = j i = j + # Double-quoted identifier (PostgreSQL, MySQL ANSI_QUOTES) + elif sql[i] == '"': + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == '"': + j += 1 + if j >= n or sql[j] != '"': + break # end of identifier + j += 1 # "" escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != '"'): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated double-quoted identifier detected in SQL" + ) + tokens.append(("identifier", sql[i:j])) + last = j + i = j + # Backtick-quoted identifier (MySQL) + elif sql[i] == "`": + if i > last: + tokens.append(("code", sql[last:i])) + j = i + 1 + while j < n: + if sql[j] == "`": + j += 1 + if j >= n or sql[j] != "`": + break # end of identifier + j += 1 # `` escape — continue + else: + j += 1 + # Check for unterminated identifier + if j >= n and (j == i + 1 or sql[j - 1] != "`"): + from row_query.core.exceptions import SQLSanitizationError + raise SQLSanitizationError( + "Unterminated backtick-quoted identifier detected in SQL" + ) + tokens.append(("identifier", sql[i:j])) + last = j + i = j else: i += 1 @@ -88,10 +144,10 @@ def _strip_comments_in_code(code: str) -> str: def _strip_comments(sql: str) -> str: - """Remove SQL comments while preserving string literals.""" + """Remove SQL comments while preserving string literals and identifiers.""" parts: list[str] = [] for kind, content in _tokenize(sql): - if kind == "string": + if kind in ("string", "identifier"): parts.append(content) else: parts.append(_strip_comments_in_code(content)) @@ -101,7 +157,7 @@ def _strip_comments(sql: str) -> str: def _check_single_statement(sql: str) -> None: """Raise if *sql* contains a semicolon followed by non-whitespace content.""" for kind, content in _tokenize(sql): - if kind == "string": + if kind in ("string", "identifier"): continue for i, ch in enumerate(content): if ch == ";" and content[i + 1 :].strip(): @@ -133,6 +189,22 @@ class SQLSanitizer: Applied only to raw SQL passed directly to engine/transaction methods. Registry-loaded queries are always trusted and never sanitized. + **IMPORTANT SECURITY WARNING:** + This sanitizer does NOT protect against SQL injection if user-provided + data is concatenated directly into SQL strings. You MUST use parameterized + queries with placeholders (e.g., `?` or `:name`) to prevent SQL injection. + The sanitizer only provides defense-in-depth measures (comment stripping, + statement blocking, verb restrictions) but is NOT a substitute for proper + parameterization. + + Example of UNSAFE code: + # NEVER DO THIS - vulnerable to SQL injection + engine.fetch_all(f"SELECT * FROM users WHERE name = '{user_input}'") + + Example of SAFE code: + # ALWAYS USE THIS - parameterized query + engine.fetch_all("SELECT * FROM users WHERE name = ?", user_input) + Attributes: strip_comments: Strip ``--`` and ``/* */`` comments before execution. block_multiple_statements: Reject SQL that contains a statement- diff --git a/row_query/core/transaction.py b/row_query/core/transaction.py index b0c2bf8..62a0c4b 100644 --- a/row_query/core/transaction.py +++ b/row_query/core/transaction.py @@ -9,24 +9,12 @@ from enum import Enum from typing import Any -from row_query.core.exceptions import TransactionStateError -from row_query.core.params import coerce_params, is_raw_sql, normalize_params +from row_query.core.exceptions import ParameterBindingError, TransactionStateError +from row_query.core.params import coerce_params, normalize_params, resolve_sql from row_query.core.registry import SQLRegistry from row_query.core.sanitizer import SQLSanitizer -def _resolve_sql( - query: str, - registry: SQLRegistry, - sanitizer: SQLSanitizer | None = None, -) -> tuple[str, str]: - """Return ``(sql_text, label)`` for *query* (raw SQL or registry key).""" - if is_raw_sql(query): - sql = sanitizer.sanitize(query) if sanitizer is not None else query - return sql, "" - return registry.get(query), query - - class _TxState(Enum): IDLE = "idle" ACTIVE = "active" @@ -107,9 +95,12 @@ def execute( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) def fetch_one( @@ -123,9 +114,12 @@ def fetch_one( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e rows = _rows_to_dicts(cursor) if not rows: return None @@ -142,9 +136,12 @@ def fetch_all( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + try: + cursor = self._adapter.execute(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return _rows_to_dicts(cursor) def commit(self) -> None: @@ -232,9 +229,12 @@ async def execute( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e return int(cursor.rowcount) async def fetch_one( @@ -248,9 +248,12 @@ async def fetch_one( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return None columns = [desc[0] for desc in cursor.description] @@ -277,9 +280,12 @@ async def fetch_all( *params* may be a dict, tuple/list, or scalar. """ self._check_active() - sql, _label = _resolve_sql(query, self._registry, self._sanitizer) + sql, label = resolve_sql(query, self._registry, self._sanitizer) sql = normalize_params(sql, self._paramstyle) - cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + try: + cursor = await self._adapter.execute_async(self._connection, sql, coerce_params(params)) + except Exception as e: + raise ParameterBindingError(label, str(e)) from e if cursor.description is None: return [] columns = [desc[0] for desc in cursor.description]