diff --git a/shiba/__init__.py b/shiba/__init__.py index 67abc73..a4ce4a8 100644 --- a/shiba/__init__.py +++ b/shiba/__init__.py @@ -93,6 +93,16 @@ def transaction(self) -> AbstractContextManager[Database]: """Context manager transaccional. Ver :meth:`Database.transaction`.""" return self.db.transaction() + def raw( + self, + query: str, + params: object = None, + *, + many: bool = False, + ) -> list[dict[str, object]]: + """Escape hatch — ver :meth:`Database.raw`.""" + return self.db.raw(query, params, many=many) + __all__ = [ "ConnectionError", diff --git a/shiba/core/query_builder.py b/shiba/core/query_builder.py index 9a6eeb4..449acc7 100644 --- a/shiba/core/query_builder.py +++ b/shiba/core/query_builder.py @@ -2,15 +2,18 @@ Construye SQL con **placeholders parametrizados** para todos los valores y delega quoting de identificadores al :class:`~shiba.dialects.base.Dialect`. -Esto cierra el agujero de SQL injection que tenía la v1.x. """ from __future__ import annotations +import re +from collections.abc import Callable, Iterator from typing import TYPE_CHECKING, Any from shiba import error_codes from shiba.identifiers import validate_identifier, validate_operator +_JSON_PATH_RE = re.compile(r"^\$(\.[A-Za-z_][A-Za-z0-9_]*|\[[0-9]+\])+$") + if TYPE_CHECKING: from shiba.dialects.base import Dialect from shiba.dialects.mysql.driver import Database @@ -19,8 +22,13 @@ _VALID_JOIN_TYPES = frozenset({"JOIN", "LEFT JOIN", "RIGHT JOIN", "INNER JOIN", "CROSS JOIN"}) +# Cada cláusula WHERE acumulada es (connector, sql_fragment, params). +# `connector` se ignora en la primera; las siguientes se concatenan con él. +_WhereItem = tuple[str, str, list[Any]] + + class QueryBuilder: - """API fluida para consultas. Inmutable-ish: cada método retorna ``self``.""" + """API fluida para consultas. Cada método retorna ``self``.""" def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: self.db = db @@ -28,7 +36,9 @@ def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: self.table_name = validate_identifier(table_name, kind="table") self._selected: list[str] = [] self._joins: list[str] = [] - self._where: list[tuple[str, str, Any]] = [] + self._where: list[_WhereItem] = [] + self._group_by: list[str] = [] + self._having: list[_WhereItem] = [] self._order_by: list[tuple[str, str]] = [] self._limit: int | None = None self._offset: int | None = None @@ -97,50 +107,161 @@ def cross_join(self, table_name: str) -> QueryBuilder: return self # ------------------------------------------------------------------ - # WHERE + # WHERE — fragmentos atómicos + # ------------------------------------------------------------------ + + def _where_fragment( + self, column: str, operator: str, value: Any + ) -> tuple[str, list[Any]]: + """Devuelve ``(sql, params)`` para una condición simple.""" + validate_identifier(column, kind="column") + op = validate_operator(operator) + col_sql = self.dialect.quote_identifier(column) + if op in {"IN", "NOT IN"}: + if not isinstance(value, (list, tuple)) or not value: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"{op} requiere lista/tupla no vacía." + ) + placeholders = ", ".join([self.dialect.placeholder] * len(value)) + return f"{col_sql} {op} ({placeholders})", list(value) + if op in {"IS", "IS NOT"} and value is None: + return f"{col_sql} {op} NULL", [] + return f"{col_sql} {op} {self.dialect.placeholder}", [value] + + def _push_where( + self, + connector: str, + column: str, + operator: str, + value: Any, + *, + target: list[_WhereItem] | None = None, + ) -> None: + sql, params = self._where_fragment(column, operator, value) + (target if target is not None else self._where).append((connector, sql, params)) + + # ------------------------------------------------------------------ + # WHERE — API pública # ------------------------------------------------------------------ def where(self, *args: Any) -> QueryBuilder: - """``where(col, val)`` o ``where(col, op, val)`` o ``where([[...], [...]])``.""" - # Forma con lista de condiciones. + """``where(col, val)``, ``where(col, op, val)`` o ``where([[...], [...]])``.""" if len(args) == 1 and isinstance(args[0], list): return self._where_many(args[0]) + column, operator, value = _unpack_condition(args) + self._push_where("AND", column, operator, value) + return self - if len(args) == 2: - column, value = args - operator = "=" - elif len(args) == 3: - column, operator, value = args - else: + def or_where(self, *args: Any) -> QueryBuilder: + column, operator, value = _unpack_condition(args) + self._push_where("OR", column, operator, value) + return self + + def where_in(self, column: str, values: list[Any] | tuple[Any, ...]) -> QueryBuilder: + self._push_where("AND", column, "IN", values) + return self + + def where_not_in( + self, column: str, values: list[Any] | tuple[Any, ...] + ) -> QueryBuilder: + self._push_where("AND", column, "NOT IN", values) + return self + + def where_null(self, column: str) -> QueryBuilder: + self._push_where("AND", column, "IS", None) + return self + + def where_not_null(self, column: str) -> QueryBuilder: + self._push_where("AND", column, "IS NOT", None) + return self + + def where_like(self, column: str, pattern: str) -> QueryBuilder: + self._push_where("AND", column, "LIKE", pattern) + return self + + def where_json( + self, + column: str, + path: str, + value: Any, + operator: str = "=", + ) -> QueryBuilder: + """Filtra por un campo dentro de una columna JSON. + + ``path`` se valida contra ``$.foo.bar`` / ``$[0]``; **no** se + parametriza (es estructura, no valor) pero sí se restringe a un + alfabeto seguro. + """ + validate_identifier(column, kind="column") + if not _JSON_PATH_RE.match(path): raise error_codes.INVALID_QUERY_PARAMS.build( - f"where() acepta 2 o 3 argumentos, recibió {len(args)}." + f"path JSON inválido: {path!r}. Usa $.foo o $[0]." ) + op = validate_operator(operator) + col_sql = self.dialect.quote_identifier(column) + sql = ( + f"JSON_UNQUOTE(JSON_EXTRACT({col_sql}, '{path}')) " + f"{op} {self.dialect.placeholder}" + ) + self._where.append(("AND", sql, [value])) + return self + def where_between(self, column: str, low: Any, high: Any) -> QueryBuilder: validate_identifier(column, kind="column") - op = validate_operator(operator) - self._where.append((column, op, value)) + col_sql = self.dialect.quote_identifier(column) + ph = self.dialect.placeholder + self._where.append( + ("AND", f"{col_sql} BETWEEN {ph} AND {ph}", [low, high]) + ) + return self + + def where_group(self, callback: Callable[[QueryBuilder], None]) -> QueryBuilder: + """Agrupa condiciones entre paréntesis. Útil para mezclar AND/OR. + + .. code-block:: python + + q.where("active", True).where_group( + lambda g: g.where("role", "admin").or_where("role", "owner") + ) + """ + sub = QueryBuilder(self.db, self.table_name, dialect=self.dialect) + callback(sub) + if not sub._where: + return self + group_sql, group_params = _compile_where_clause(sub._where, leading=False) + self._where.append(("AND", f"({group_sql})", group_params)) return self def _where_many(self, conditions: list[Any]) -> QueryBuilder: for cond in conditions: if not isinstance(cond, (list, tuple)): raise error_codes.INVALID_QUERY_PARAMS.build( - f"cada condición de where() debe ser list/tuple, recibió {type(cond).__name__}." - ) - if len(cond) == 2: - self.where(cond[0], cond[1]) - elif len(cond) == 3: - self.where(cond[0], cond[1], cond[2]) - else: - raise error_codes.INVALID_QUERY_PARAMS.build( - f"condición con {len(cond)} elementos; se esperaban 2 o 3." + f"cada condición debe ser list/tuple, recibió {type(cond).__name__}." ) + column, operator, value = _unpack_condition(tuple(cond)) + self._push_where("AND", column, operator, value) return self # ------------------------------------------------------------------ - # ORDER / LIMIT + # GROUP BY / HAVING / ORDER / LIMIT # ------------------------------------------------------------------ + def group_by(self, *columns: str) -> QueryBuilder: + if not columns: + raise error_codes.MISSING_REQUIRED_DATA.build( + "group_by() requiere al menos una columna." + ) + for col in columns: + validate_identifier(col, kind="column") + self._group_by.extend(columns) + return self + + def having(self, *args: Any) -> QueryBuilder: + column, operator, value = _unpack_condition(args) + sql, params = self._where_fragment(column, operator, value) + self._having.append(("AND", sql, params)) + return self + def order_by(self, column: str, direction: str = "ASC") -> QueryBuilder: validate_identifier(column, kind="column") d = direction.strip().upper() @@ -160,71 +281,178 @@ def offset(self, n: int) -> QueryBuilder: return self # ------------------------------------------------------------------ - # Compilación de WHERE + # Compilación # ------------------------------------------------------------------ def _compile_where(self) -> tuple[str, list[Any]]: if not self._where: return "", [] - parts: list[str] = [] - params: list[Any] = [] - for column, op, value in self._where: - col_sql = self.dialect.quote_identifier(column) - if op in {"IN", "NOT IN"}: - if not isinstance(value, (list, tuple)) or not value: - raise error_codes.INVALID_QUERY_PARAMS.build( - f"{op} requiere lista/tupla no vacía." - ) - placeholders = ", ".join([self.dialect.placeholder] * len(value)) - parts.append(f"{col_sql} {op} ({placeholders})") - params.extend(value) - elif op in {"IS", "IS NOT"} and value is None: - parts.append(f"{col_sql} {op} NULL") - else: - parts.append(f"{col_sql} {op} {self.dialect.placeholder}") - params.append(value) - return "WHERE " + " AND ".join(parts), params - - # ------------------------------------------------------------------ - # SELECT execution - # ------------------------------------------------------------------ - - def get(self) -> list[dict[str, Any]]: - select_clause = "*" - if self._selected: - select_clause = ", ".join(self.dialect.quote_identifier(c) for c in self._selected) - - joins = " ".join(self._joins) + sql, params = _compile_where_clause(self._where, leading=True) + return sql, params + + def _compile_having(self) -> str: + if not self._having: + return "" + sql, _ = _compile_where_clause(self._having, leading=False) + return "HAVING " + sql + + def _compile_select_clause(self) -> str: + if not self._selected: + return "*" + return ", ".join(self.dialect.quote_identifier(c) for c in self._selected) + + def _compile_tail(self) -> tuple[str, list[Any]]: + """Devuelve la SQL común WHERE+GROUP+HAVING+ORDER+LIMIT y sus params.""" where_sql, params = self._compile_where() - + having_sql = self._compile_having() + for _, _, p in self._having: + params.extend(p) + + group_sql = "" + if self._group_by: + group_sql = "GROUP BY " + ", ".join( + self.dialect.quote_identifier(c) for c in self._group_by + ) order_sql = "" if self._order_by: order_sql = "ORDER BY " + ", ".join( f"{self.dialect.quote_identifier(c)} {d}" for c, d in self._order_by ) limit_sql = self.dialect.render_limit(self._limit, self._offset) + tail = " ".join(p for p in [where_sql, group_sql, having_sql, order_sql, limit_sql] if p) + return tail, params + + # ------------------------------------------------------------------ + # Lectura + # ------------------------------------------------------------------ + def get(self) -> list[dict[str, Any]]: + select_clause = self._compile_select_clause() + joins = " ".join(self._joins) + tail, params = self._compile_tail() table = self.dialect.quote_identifier(self.table_name) - parts = [f"SELECT {select_clause} FROM {table}", joins, where_sql, order_sql, limit_sql] - query = " ".join(p for p in parts if p).strip() - return self.db.execute(query, tuple(params) if params else None) + query = " ".join( + p for p in [f"SELECT {select_clause} FROM {table}", joins, tail] if p + ) + return self.db.execute(query.strip(), tuple(params) if params else None) def first(self) -> dict[str, Any] | None: self._limit = 1 rows = self.get() return rows[0] if rows else None - def count(self, column: str = "*") -> int: + def find(self, pk_value: Any, *, pk: str = "id") -> dict[str, Any] | None: + return self.where(pk, pk_value).first() + + def exists(self) -> bool: + return self.count() > 0 + + def pluck(self, column: str) -> list[Any]: + validate_identifier(column, kind="column") + rows = self.select(column).get() + return [row[column] for row in rows] + + def _aggregate(self, fn: str, column: str) -> Any: col = "*" if column == "*" else self.dialect.quote_identifier(column) joins = " ".join(self._joins) - where_sql, params = self._compile_where() + tail, params = self._compile_tail() table = self.dialect.quote_identifier(self.table_name) - query = f"SELECT COUNT({col}) AS cnt FROM {table} {joins} {where_sql}".strip() - rows = self.db.execute(query, tuple(params) if params else None) - return int(rows[0]["cnt"]) if rows else 0 + query = " ".join( + p for p in [f"SELECT {fn}({col}) AS v FROM {table}", joins, tail] if p + ) + rows = self.db.execute(query.strip(), tuple(params) if params else None) + return rows[0]["v"] if rows else None + + def count(self, column: str = "*") -> int: + result = self._aggregate("COUNT", column) + return int(result) if result is not None else 0 + + def sum(self, column: str) -> Any: + return self._aggregate("SUM", column) + + def avg(self, column: str) -> Any: + return self._aggregate("AVG", column) + + def min(self, column: str) -> Any: + return self._aggregate("MIN", column) + + def max(self, column: str) -> Any: + return self._aggregate("MAX", column) # ------------------------------------------------------------------ - # INSERT / UPDATE / DELETE + # Paginación y streaming + # ------------------------------------------------------------------ + + def paginate(self, page: int = 1, per_page: int = 25) -> dict[str, Any]: + """Devuelve ``{page, per_page, total, last_page, data}``.""" + if page < 1 or per_page < 1: + raise error_codes.INVALID_QUERY_PARAMS.build( + "paginate() exige page>=1 y per_page>=1." + ) + total = QueryBuilder._clone_for_count(self).count() + last_page = max(1, (total + per_page - 1) // per_page) + self._limit = per_page + self._offset = (page - 1) * per_page + return { + "page": page, + "per_page": per_page, + "total": total, + "last_page": last_page, + "data": self.get(), + } + + @staticmethod + def _clone_for_count(src: QueryBuilder) -> QueryBuilder: + """Clon ligero que comparte WHERE/JOIN pero sin limit/offset/order.""" + clone = QueryBuilder(src.db, src.table_name, dialect=src.dialect) + clone._joins = list(src._joins) + clone._where = list(src._where) + clone._group_by = list(src._group_by) + clone._having = list(src._having) + return clone + + def chunk( + self, + size: int, + callback: Callable[[list[dict[str, Any]]], None], + *, + order_by_pk: str = "id", + ) -> None: + """Procesa la consulta en lotes de ``size`` filas. + + Pagina por ``OFFSET`` (suficiente para tablas medianas). Para + tablas muy grandes usar :meth:`iterate` con cursor keyset. + """ + if size < 1: + raise error_codes.INVALID_QUERY_PARAMS.build("chunk size debe ser >= 1.") + offset = 0 + while True: + clone = QueryBuilder._clone_for_count(self) + clone._order_by = list(self._order_by) or [(order_by_pk, "ASC")] + clone._limit = size + clone._offset = offset + batch = clone.get() + if not batch: + return + callback(batch) + if len(batch) < size: + return + offset += size + + def iterate( + self, + chunk_size: int = 1000, + *, + order_by_pk: str = "id", + ) -> Iterator[dict[str, Any]]: + """Generator que recorre toda la consulta por lotes.""" + batches: list[list[dict[str, Any]]] = [] + self.chunk(chunk_size, batches.append, order_by_pk=order_by_pk) + for batch in batches: + yield from batch + + # ------------------------------------------------------------------ + # INSERT / UPDATE / DELETE / UPSERT # ------------------------------------------------------------------ def insert(self, data: dict[str, Any] | list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -253,12 +481,12 @@ def _insert_many(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: if not first_keys: raise error_codes.MISSING_REQUIRED_DATA.build("insert(): filas vacías.") cols = [validate_identifier(k, kind="column") for k in first_keys] - # Forzamos que todas las filas tengan las mismas claves y en el mismo orden. values: list[tuple[Any, ...]] = [] for row in rows: if list(row.keys()) != first_keys: raise error_codes.INVALID_DATA_FORMAT.build( - "insert(): todas las filas deben tener las mismas claves en el mismo orden." + "insert(): todas las filas deben tener las mismas claves " + "en el mismo orden." ) values.append(tuple(row[k] for k in first_keys)) cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) @@ -267,6 +495,33 @@ def _insert_many(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: query = f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders})" return self.db.execute(query, values, many=True) + def upsert( + self, + data: dict[str, Any], + *, + update: list[str] | None = None, + ) -> list[dict[str, Any]]: + """INSERT con resolución de conflicto. + + En MySQL se emite ``ON DUPLICATE KEY UPDATE``. El parámetro + ``update`` indica qué columnas pisar (por defecto todas excepto + las que sean clave). El dialecto adapta la cláusula. + """ + if not data: + raise error_codes.MISSING_REQUIRED_DATA.build("upsert(): dict vacío.") + cols = [validate_identifier(k, kind="column") for k in data] + cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) + placeholders = ", ".join([self.dialect.placeholder] * len(cols)) + update_cols = update if update is not None else list(data.keys()) + for col in update_cols: + validate_identifier(col, kind="column") + update_sql = self.dialect.compile_upsert_update(update_cols) + table = self.dialect.quote_identifier(self.table_name) + query = ( + f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders}) {update_sql}" + ) + return self.db.execute(query, tuple(data.values())) + def update(self, data: dict[str, Any]) -> list[dict[str, Any]]: if not data: raise error_codes.MISSING_REQUIRED_DATA.build("update() requiere datos.") @@ -274,7 +529,9 @@ def update(self, data: dict[str, Any]) -> list[dict[str, Any]]: params: list[Any] = [] for col, val in data.items(): validate_identifier(col, kind="column") - set_parts.append(f"{self.dialect.quote_identifier(col)} = {self.dialect.placeholder}") + set_parts.append( + f"{self.dialect.quote_identifier(col)} = {self.dialect.placeholder}" + ) params.append(val) where_sql, where_params = self._compile_where() params.extend(where_params) @@ -291,3 +548,37 @@ def delete(self) -> list[dict[str, Any]]: table = self.dialect.quote_identifier(self.table_name) query = f"DELETE FROM {table} {where_sql}".strip() return self.db.execute(query, tuple(params)) + + def truncate(self) -> list[dict[str, Any]]: + """``TRUNCATE TABLE`` — borra todas las filas y resetea AUTO_INCREMENT.""" + table = self.dialect.quote_identifier(self.table_name) + return self.db.execute(f"TRUNCATE TABLE {table}") + + +# --------------------------------------------------------------------------- +# Helpers de módulo +# --------------------------------------------------------------------------- + +def _unpack_condition(args: tuple[Any, ...]) -> tuple[str, str, Any]: + if len(args) == 2: + return args[0], "=", args[1] + if len(args) == 3: + return args[0], args[1], args[2] + raise error_codes.INVALID_QUERY_PARAMS.build( + f"se esperaban 2 o 3 argumentos, llegaron {len(args)}." + ) + + +def _compile_where_clause( + items: list[_WhereItem], *, leading: bool +) -> tuple[str, list[Any]]: + parts: list[str] = [] + params: list[Any] = [] + for idx, (connector, sql, p) in enumerate(items): + if idx == 0: + parts.append(sql) + else: + parts.append(f"{connector} {sql}") + params.extend(p) + body = " ".join(parts) + return (f"WHERE {body}", params) if leading else (body, params) diff --git a/shiba/dialects/base.py b/shiba/dialects/base.py index 0cf5679..1cece2f 100644 --- a/shiba/dialects/base.py +++ b/shiba/dialects/base.py @@ -41,3 +41,11 @@ def render_limit(self, limit: int | None, offset: int | None) -> str: if offset is not None: parts.append(f"OFFSET {int(offset)}") return " ".join(parts) + + @abstractmethod + def compile_upsert_update(self, update_columns: list[str]) -> str: + """Cláusula de resolución de conflicto para ``upsert``. + + MySQL → ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` + Postgres/SQLite → ``ON CONFLICT (...) DO UPDATE SET ...``. + """ diff --git a/shiba/dialects/mysql/dialect.py b/shiba/dialects/mysql/dialect.py index 2b7685d..91e237d 100644 --- a/shiba/dialects/mysql/dialect.py +++ b/shiba/dialects/mysql/dialect.py @@ -17,3 +17,9 @@ def quote_identifier(self, name: str) -> str: def map_type(self, declared: str) -> str: return _map_type(declared) + + def compile_upsert_update(self, update_columns: list[str]) -> str: + if not update_columns: + return "" + parts = [f"{_qi(c)} = VALUES({_qi(c)})" for c in update_columns] + return "ON DUPLICATE KEY UPDATE " + ", ".join(parts) diff --git a/shiba/dialects/mysql/driver.py b/shiba/dialects/mysql/driver.py index 8127038..62b83fc 100644 --- a/shiba/dialects/mysql/driver.py +++ b/shiba/dialects/mysql/driver.py @@ -192,6 +192,20 @@ def execute( # Alias retro-compatible con la API v1.x. execute_query = execute + def raw( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + """Escape hatch para SQL crudo — sin builder. + + El llamador es responsable de pasar **valores siempre como + parámetros**, nunca interpolados en ``query``. + """ + return self.execute(query, params, many=many) + def _rollback_silent(self) -> None: if self._connection is None or not self._connection.open: return diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index e1a3539..572f4a8 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -132,8 +132,8 @@ def test_delete_with_where(fake_db, dialect) -> None: def test_count(fake_db, dialect) -> None: - fake_db.result = [{"cnt": 42}] + fake_db.result = [{"v": 42}] n = QueryBuilder(fake_db, "users", dialect=dialect).where("active", True).count() assert n == 42 sql, _, _ = fake_db.last_call - assert sql.startswith("SELECT COUNT(*) AS cnt FROM `users`") + assert sql.startswith("SELECT COUNT(*) AS v FROM `users`") diff --git a/tests/test_query_builder_fase1.py b/tests/test_query_builder_fase1.py new file mode 100644 index 0000000..8d9a59c --- /dev/null +++ b/tests/test_query_builder_fase1.py @@ -0,0 +1,242 @@ +"""Cobertura de features añadidas en Fase 1. + +WHERE: or_where, where_in/null/between/like, where_group, where_json. +GROUP BY / HAVING. paginate / chunk / iterate. find / exists / pluck. +sum/avg/min/max. upsert. truncate. raw(). +""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.errors import QueryError, SchemaError + +# --------------------------------------------------------------------------- +# WHERE variants +# --------------------------------------------------------------------------- + +def test_or_where(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("a", 1).or_where("b", 2).get() + sql, params, _ = fake_db.last_call + assert "WHERE `a` = %s OR `b` = %s" in sql + assert params == (1, 2) + + +def test_where_in_alias(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_in("id", [1, 2]).get() + sql, params, _ = fake_db.last_call + assert "WHERE `id` IN (%s, %s)" in sql + assert params == (1, 2) + + +def test_where_null(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_null("deleted_at").get() + sql, _, _ = fake_db.last_call + assert "WHERE `deleted_at` IS NULL" in sql + + +def test_where_not_null(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_not_null("email").get() + sql, _, _ = fake_db.last_call + assert "WHERE `email` IS NOT NULL" in sql + + +def test_where_between(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_between("age", 18, 65).get() + sql, params, _ = fake_db.last_call + assert "WHERE `age` BETWEEN %s AND %s" in sql + assert params == (18, 65) + + +def test_where_like(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_like("name", "John%").get() + sql, params, _ = fake_db.last_call + assert "WHERE `name` LIKE %s" in sql + assert params == ("John%",) + + +def test_where_group_mixes_and_or(fake_db, dialect) -> None: + ( + QueryBuilder(fake_db, "u", dialect=dialect) + .where("active", True) + .where_group(lambda g: g.where("role", "admin").or_where("role", "owner")) + .get() + ) + sql, params, _ = fake_db.last_call + assert "WHERE `active` = %s AND (`role` = %s OR `role` = %s)" in sql + assert params == (True, "admin", "owner") + + +# --------------------------------------------------------------------------- +# JSON +# --------------------------------------------------------------------------- + +def test_where_json_extracts_path(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_json("settings", "$.theme", "dark").get() + sql, params, _ = fake_db.last_call + assert "JSON_UNQUOTE(JSON_EXTRACT(`settings`, '$.theme')) = %s" in sql + assert params == ("dark",) + + +def test_where_json_rejects_unsafe_path(fake_db, dialect) -> None: + with pytest.raises(QueryError) as ei: + QueryBuilder(fake_db, "u", dialect=dialect).where_json( + "settings", "$.theme'; DROP TABLE x; --", "dark" + ) + assert ei.value.code is error_codes.INVALID_QUERY_PARAMS + + +# --------------------------------------------------------------------------- +# GROUP BY / HAVING +# --------------------------------------------------------------------------- + +def test_group_by_and_having(fake_db, dialect) -> None: + fake_db.result = [{"v": 0}] + ( + QueryBuilder(fake_db, "orders", dialect=dialect) + .select("user_id") + .group_by("user_id") + .having("total", ">", 100) + .get() + ) + sql, params, _ = fake_db.last_call + assert "GROUP BY `user_id`" in sql + assert "HAVING `total` > %s" in sql + assert params == (100,) + + +# --------------------------------------------------------------------------- +# Aggregates +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "method,fn", + [("sum", "SUM"), ("avg", "AVG"), ("min", "MIN"), ("max", "MAX")], +) +def test_aggregates(fake_db, dialect, method: str, fn: str) -> None: + fake_db.result = [{"v": 7}] + qb = QueryBuilder(fake_db, "u", dialect=dialect) + result = getattr(qb, method)("age") + assert result == 7 + sql, _, _ = fake_db.last_call + assert f"SELECT {fn}(`age`) AS v FROM `users`".replace("`users`", "`u`") in sql + + +# --------------------------------------------------------------------------- +# Convenience reads +# --------------------------------------------------------------------------- + +def test_find_by_pk(fake_db, dialect) -> None: + fake_db.result = [{"id": 5, "name": "X"}] + row = QueryBuilder(fake_db, "u", dialect=dialect).find(5) + assert row == {"id": 5, "name": "X"} + sql, params, _ = fake_db.last_call + assert "WHERE `id` = %s" in sql + assert "LIMIT 1" in sql + assert params == (5,) + + +def test_exists_true_false(fake_db, dialect) -> None: + fake_db.result = [{"v": 3}] + assert QueryBuilder(fake_db, "u", dialect=dialect).exists() is True + fake_db.result = [{"v": 0}] + assert QueryBuilder(fake_db, "u", dialect=dialect).exists() is False + + +def test_pluck(fake_db, dialect) -> None: + fake_db.result = [{"name": "A"}, {"name": "B"}] + names = QueryBuilder(fake_db, "u", dialect=dialect).pluck("name") + assert names == ["A", "B"] + sql, _, _ = fake_db.last_call + assert sql.startswith("SELECT `name` FROM `u`") + + +# --------------------------------------------------------------------------- +# Paginate +# --------------------------------------------------------------------------- + +def test_paginate(fake_db, dialect) -> None: + # Primer execute = COUNT; segundo = data. + results = iter([[{"v": 53}], [{"id": i} for i in range(1, 26)]]) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(results) + + fake_db.execute = execute # type: ignore[method-assign] + + page = QueryBuilder(fake_db, "u", dialect=dialect).paginate(page=2, per_page=25) + assert page["page"] == 2 + assert page["per_page"] == 25 + assert page["total"] == 53 + assert page["last_page"] == 3 + assert len(page["data"]) == 25 + # La segunda llamada debe tener LIMIT 25 OFFSET 25. + data_sql, _, _ = fake_db.calls[-1] + assert "LIMIT 25" in data_sql + assert "OFFSET 25" in data_sql + + +def test_paginate_rejects_zero(fake_db, dialect) -> None: + with pytest.raises(QueryError) as ei: + QueryBuilder(fake_db, "u", dialect=dialect).paginate(page=0, per_page=10) + assert ei.value.code is error_codes.INVALID_QUERY_PARAMS + + +# --------------------------------------------------------------------------- +# Chunk / iterate +# --------------------------------------------------------------------------- + +def test_chunk_iterates_until_empty(fake_db, dialect) -> None: + pages = iter( + [ + [{"id": 1}, {"id": 2}], + [{"id": 3}, {"id": 4}], + [{"id": 5}], + [], # nunca debería pedirse pero por seguridad + ] + ) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(pages) + + fake_db.execute = execute # type: ignore[method-assign] + + collected: list[dict[str, int]] = [] + QueryBuilder(fake_db, "u", dialect=dialect).chunk(2, collected.extend) + assert [r["id"] for r in collected] == [1, 2, 3, 4, 5] + + +# --------------------------------------------------------------------------- +# UPSERT / TRUNCATE / RAW +# --------------------------------------------------------------------------- + +def test_upsert_emits_on_duplicate_key(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).upsert({"id": 1, "name": "X"}) + sql, params, _ = fake_db.last_call + assert sql == ( + "INSERT INTO `u` (`id`, `name`) VALUES (%s, %s) " + "ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`)" + ) + assert params == (1, "X") + + +def test_upsert_with_explicit_update_columns(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).upsert( + {"id": 1, "name": "X", "age": 30}, update=["name", "age"] + ) + sql, _, _ = fake_db.last_call + assert "ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `age` = VALUES(`age`)" in sql + + +def test_truncate(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).truncate() + sql, _, _ = fake_db.last_call + assert sql == "TRUNCATE TABLE `u`" + + +def test_invalid_column_in_pluck(fake_db, dialect) -> None: + with pytest.raises(SchemaError): + QueryBuilder(fake_db, "u", dialect=dialect).pluck("name; DROP")