From 3a0e864a1f6d56bf5c396099390bccfd2dc02dc6 Mon Sep 17 00:00:00 2001 From: Rodrigo Pino Date: Mon, 18 May 2026 01:32:57 -0400 Subject: [PATCH] =?UTF-8?q?feat(v2):=20rewrite=20completo=20=E2=80=94=20Fa?= =?UTF-8?q?se=200=E2=80=934=20con=20ORM=20y=20multi-dialecto?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolida en un solo commit el trabajo de los PRs #5–#9 (cuyas ramas quedaron divergentes tras varios force-pushes para fix de CI). El detalle por fase vive en los PRs cerrados. Fase 0 — Seguridad y estructura - Paquete `shiba/` modular con core/ agnóstico y dialects/. - SQL injection cerrada: valores parametrizados, identificadores validados y citados por dialecto. - Catálogo `SHIBA-XYZW` de error codes con mapper desde drivers. - Excepciones tipadas, transacciones explícitas, context manager. - `shibamysql/` shim deprecated. - Tooling: pyproject.toml, ruff, mypy --strict, pytest, CI. Fase 1 — Query builder rico - `or_where`, `where_in/not_in/null/not_null/like/between`, `where_group(callback)`, `where_json(col, path, value)`. - `group_by`, `having`, `find`, `exists`, `pluck`, agregados. - `paginate`, `chunk`, `iterate`. - `upsert(data, on=[...])`, `truncate`, `raw` escape hatch. Fase 1.5 — ORM híbrido tipado - `Model` con metaclass que lee anotaciones por MRO con `inspect.get_annotations(eval_str=True)`. - 12 subclases de `Field` (PrimaryKey, String, Json, DateTime, DecimalField, Enum, ForeignKey, ...) e inferencia automática. - `ModelQuery[T]` hidrata filas a instancias. - Conexión global (`set_default_connection`) o por clase (`__db__`). Fase 2 — PostgreSQL - Dialect con doble comilla, `JSONB`, `BYTEA`, `TIMESTAMP`, `GENERATED ALWAYS AS IDENTITY`, `ON CONFLICT DO UPDATE`. - Driver psycopg3 con import perezoso (dep opcional). - `shiba.connect("mysql://...")`/`postgres://...` factory. Fase 4 — Tests de integración - 7 tests sobre MySQL 8 y Postgres 16 con Testcontainers, marker `integration`, job CI separado. Stats - 114 unit tests + 7 integration verdes (39s end-to-end). - ruff y mypy --strict limpios sobre 25 archivos. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/ci.yml | 27 +- pyproject.toml | 9 +- shiba/__init__.py | 120 +++++-- shiba/core/query_builder.py | 433 +++++++++++++++++++++---- shiba/core/table_builder.py | 11 +- shiba/dialects/base.py | 22 ++ shiba/dialects/mysql/dialect.py | 10 + shiba/dialects/mysql/driver.py | 14 + shiba/dialects/postgres/__init__.py | 19 ++ shiba/dialects/postgres/dialect.py | 39 +++ shiba/dialects/postgres/driver.py | 238 ++++++++++++++ shiba/dialects/postgres/quoting.py | 13 + shiba/dialects/postgres/schema.py | 26 ++ shiba/error_codes.py | 4 +- shiba/identifiers.py | 4 +- shiba/orm/__init__.py | 40 +++ shiba/orm/fields.py | 305 +++++++++++++++++ shiba/orm/model.py | 314 ++++++++++++++++++ tests/integration/__init__.py | 0 tests/integration/conftest.py | 67 ++++ tests/integration/test_mysql_e2e.py | 93 ++++++ tests/integration/test_postgres_e2e.py | 59 ++++ tests/test_connect_factory.py | 81 +++++ tests/test_orm.py | 293 +++++++++++++++++ tests/test_postgres_dialect.py | 92 ++++++ tests/test_query_builder.py | 4 +- tests/test_query_builder_fase1.py | 242 ++++++++++++++ 27 files changed, 2470 insertions(+), 109 deletions(-) create mode 100644 shiba/dialects/postgres/__init__.py create mode 100644 shiba/dialects/postgres/dialect.py create mode 100644 shiba/dialects/postgres/driver.py create mode 100644 shiba/dialects/postgres/quoting.py create mode 100644 shiba/dialects/postgres/schema.py create mode 100644 shiba/orm/__init__.py create mode 100644 shiba/orm/fields.py create mode 100644 shiba/orm/model.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/test_mysql_e2e.py create mode 100644 tests/integration/test_postgres_e2e.py create mode 100644 tests/test_connect_factory.py create mode 100644 tests/test_orm.py create mode 100644 tests/test_postgres_dialect.py create mode 100644 tests/test_query_builder_fase1.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b265f4..10f531c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,8 +10,8 @@ concurrency: cancel-in-progress: true jobs: - lint-type-test: - name: ${{ matrix.python-version }} · lint · type · test + unit: + name: unit · ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -37,5 +37,26 @@ jobs: - name: Mypy run: mypy - - name: Pytest + - name: Pytest (unit) run: pytest -q + + integration: + name: integration · MySQL + Postgres (testcontainers) + runs-on: ubuntu-latest + needs: unit + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Pytest (integration) + run: pytest -m integration tests/integration -v diff --git a/pyproject.toml b/pyproject.toml index 358e2ce..15fc6be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,14 @@ classifiers = [ dependencies = ["pymysql>=1.1"] [project.optional-dependencies] +postgres = ["psycopg[binary]>=3.1"] dev = [ "pytest>=8", "pytest-cov>=5", "ruff>=0.6", "mypy>=1.10", + "psycopg[binary]>=3.1", + "testcontainers[mysql,postgres]>=4.7", ] [project.urls] @@ -49,7 +52,6 @@ select = ["E", "F", "I", "B", "UP", "SIM", "RUF", "ANN", "A"] ignore = [ "ANN401", # permitir Any explícito "A004", # sombreamos ConnectionError a propósito (parte de la API pública) - "UP038", # `isinstance(x, X | Y)` rompe en py3.10 sin from __future__ import ] [tool.ruff.lint.per-file-ignores] @@ -68,4 +70,7 @@ ignore_missing_imports = true [tool.pytest.ini_options] testpaths = ["tests"] -addopts = "-ra --strict-markers" +addopts = "-ra --strict-markers -m 'not integration'" +markers = [ + "integration: requiere Docker (MySQL/Postgres reales con Testcontainers)", +] diff --git a/shiba/__init__.py b/shiba/__init__.py index 67abc73..77bbd10 100644 --- a/shiba/__init__.py +++ b/shiba/__init__.py @@ -1,31 +1,27 @@ """Shiba — librería ligera para hablar con bases de datos relacionales. -Punto de entrada público: - .. code-block:: python - import shiba as s + import shiba - with s.ShibaConnection(host="localhost", port=3306, - user="u", password="p") as cx: - cx.create_database("my_db") - cx.use_database("my_db") - cx.create_table("users") \\ - .increments("id", primary_key=True) \\ - .string("name") \\ - .build() + # Forma 1 — DSN explícito (recomendado para multi-dialecto): + cx = shiba.connect("mysql://user:pass@localhost:3306/my_db") + cx = shiba.connect("postgres://user:pass@localhost:5432/my_db") - cx.table("users").insert({"name": "John"}) - rows = cx.table("users").where("name", "John").get() + # Forma 2 — construcción directa MySQL (legacy): + cx = shiba.ShibaConnection(host="localhost", port=3306, + user="u", password="p") """ from __future__ import annotations from contextlib import AbstractContextManager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse from shiba import error_codes from shiba.core.query_builder import QueryBuilder from shiba.core.table_builder import TableBuilder +from shiba.dialects.base import Dialect from shiba.dialects.mysql import Database, MySQLDialect from shiba.errors import ( ConnectionError, @@ -35,23 +31,43 @@ SchemaError, ShibaError, ) +from shiba.orm import Model, fields, set_default_connection if TYPE_CHECKING: from types import TracebackType class ShibaConnection: - """Fachada de alto nivel sobre un :class:`Database` MySQL.""" + """Fachada de alto nivel agnóstica de dialecto. + + Acepta dos formas de construcción: + + * Legacy MySQL: ``ShibaConnection(host, port, user, password)``. + * Inyectada: ``ShibaConnection(db=Database(...), dialect=Dialect(...))`` + (la usa :func:`connect`). + """ def __init__( self, - host: str, - port: int, - user: str, - password: str, + host: str | None = None, + port: int | None = None, + user: str | None = None, + password: str | None = None, *, database: str | None = None, + db: Any = None, + dialect: Dialect | None = None, ) -> None: + if db is not None and dialect is not None: + self.dialect: Dialect = dialect + self.db: Any = db + return + + if host is None or port is None or user is None or password is None: + error_codes.MISSING_REQUIRED_DATA.raise_( + "ShibaConnection requiere host/port/user/password o " + "db+dialect inyectados." + ) self.dialect = MySQLDialect() self.db = Database(host, port, user, password, database=database) @@ -77,10 +93,10 @@ def close(self) -> None: # API pública # ------------------------------------------------------------------ - def create_database(self, database: str) -> Database: + def create_database(self, database: str) -> Any: return self.db.create_database(database) - def use_database(self, database: str) -> Database: + def use_database(self, database: str) -> Any: return self.db.use_database(database) def create_table(self, table_name: str) -> TableBuilder: @@ -89,16 +105,69 @@ def create_table(self, table_name: str) -> TableBuilder: def table(self, table_name: str) -> QueryBuilder: return QueryBuilder(self.db, table_name, dialect=self.dialect) - def transaction(self) -> AbstractContextManager[Database]: - """Context manager transaccional. Ver :meth:`Database.transaction`.""" - return self.db.transaction() + def transaction(self) -> AbstractContextManager[Any]: + """Context manager transaccional.""" + cm: AbstractContextManager[Any] = self.db.transaction() + return cm + + def raw( + self, + query: str, + params: object = None, + *, + many: bool = False, + ) -> list[dict[str, object]]: + rows: list[dict[str, object]] = self.db.raw(query, params, many=many) + return rows + + +# --------------------------------------------------------------------------- +# Factory connect(dsn) +# --------------------------------------------------------------------------- + + +_DEFAULT_PORTS = {"mysql": 3306, "postgres": 5432, "postgresql": 5432} + + +def connect(dsn: str) -> ShibaConnection: + """Construye una :class:`ShibaConnection` desde un DSN tipo URL. + + Schemes soportados: + + * ``mysql://user:pass@host:port/dbname`` + * ``postgres://user:pass@host:port/dbname`` (alias: ``postgresql://``) + """ + parsed = urlparse(dsn) + scheme = parsed.scheme.lower() + if scheme not in _DEFAULT_PORTS: + error_codes.NOT_IMPLEMENTED.raise_( + f"DSN scheme '{scheme}' no soportado. Usa: {sorted(_DEFAULT_PORTS)}." + ) + host = parsed.hostname or "localhost" + port = parsed.port or _DEFAULT_PORTS[scheme] + user = parsed.username or "" + password = parsed.password or "" + database = parsed.path.lstrip("/") or None + + if scheme == "mysql": + db: Any = Database(host, port, user, password, database=database) + return ShibaConnection(db=db, dialect=MySQLDialect()) + + # postgres / postgresql + from shiba.dialects.postgres import PostgresDialect + from shiba.dialects.postgres.driver import Database as PgDatabase + + db = PgDatabase(host, port, user, password, database=database) + return ShibaConnection(db=db, dialect=PostgresDialect()) __all__ = [ "ConnectionError", "Database", + "Dialect", "IntegrityError", "MissingDataError", + "Model", "MySQLDialect", "QueryBuilder", "QueryError", @@ -106,5 +175,8 @@ def transaction(self) -> AbstractContextManager[Database]: "ShibaConnection", "ShibaError", "TableBuilder", + "connect", "error_codes", + "fields", + "set_default_connection", ] diff --git a/shiba/core/query_builder.py b/shiba/core/query_builder.py index 9a6eeb4..bd14d0d 100644 --- a/shiba/core/query_builder.py +++ b/shiba/core/query_builder.py @@ -2,15 +2,18 @@ Construye SQL con **placeholders parametrizados** para todos los valores y delega quoting de identificadores al :class:`~shiba.dialects.base.Dialect`. -Esto cierra el agujero de SQL injection que tenía la v1.x. """ from __future__ import annotations +import re +from collections.abc import Callable, Iterator from typing import TYPE_CHECKING, Any from shiba import error_codes from shiba.identifiers import validate_identifier, validate_operator +_JSON_PATH_RE = re.compile(r"^\$(\.[A-Za-z_][A-Za-z0-9_]*|\[[0-9]+\])+$") + if TYPE_CHECKING: from shiba.dialects.base import Dialect from shiba.dialects.mysql.driver import Database @@ -19,8 +22,13 @@ _VALID_JOIN_TYPES = frozenset({"JOIN", "LEFT JOIN", "RIGHT JOIN", "INNER JOIN", "CROSS JOIN"}) +# Cada cláusula WHERE acumulada es (connector, sql_fragment, params). +# `connector` se ignora en la primera; las siguientes se concatenan con él. +_WhereItem = tuple[str, str, list[Any]] + + class QueryBuilder: - """API fluida para consultas. Inmutable-ish: cada método retorna ``self``.""" + """API fluida para consultas. Cada método retorna ``self``.""" def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: self.db = db @@ -28,7 +36,9 @@ def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: self.table_name = validate_identifier(table_name, kind="table") self._selected: list[str] = [] self._joins: list[str] = [] - self._where: list[tuple[str, str, Any]] = [] + self._where: list[_WhereItem] = [] + self._group_by: list[str] = [] + self._having: list[_WhereItem] = [] self._order_by: list[tuple[str, str]] = [] self._limit: int | None = None self._offset: int | None = None @@ -97,50 +107,161 @@ def cross_join(self, table_name: str) -> QueryBuilder: return self # ------------------------------------------------------------------ - # WHERE + # WHERE — fragmentos atómicos + # ------------------------------------------------------------------ + + def _where_fragment( + self, column: str, operator: str, value: Any + ) -> tuple[str, list[Any]]: + """Devuelve ``(sql, params)`` para una condición simple.""" + validate_identifier(column, kind="column") + op = validate_operator(operator) + col_sql = self.dialect.quote_identifier(column) + if op in {"IN", "NOT IN"}: + if not isinstance(value, list | tuple) or not value: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"{op} requiere lista/tupla no vacía." + ) + placeholders = ", ".join([self.dialect.placeholder] * len(value)) + return f"{col_sql} {op} ({placeholders})", list(value) + if op in {"IS", "IS NOT"} and value is None: + return f"{col_sql} {op} NULL", [] + return f"{col_sql} {op} {self.dialect.placeholder}", [value] + + def _push_where( + self, + connector: str, + column: str, + operator: str, + value: Any, + *, + target: list[_WhereItem] | None = None, + ) -> None: + sql, params = self._where_fragment(column, operator, value) + (target if target is not None else self._where).append((connector, sql, params)) + + # ------------------------------------------------------------------ + # WHERE — API pública # ------------------------------------------------------------------ def where(self, *args: Any) -> QueryBuilder: - """``where(col, val)`` o ``where(col, op, val)`` o ``where([[...], [...]])``.""" - # Forma con lista de condiciones. + """``where(col, val)``, ``where(col, op, val)`` o ``where([[...], [...]])``.""" if len(args) == 1 and isinstance(args[0], list): return self._where_many(args[0]) + column, operator, value = _unpack_condition(args) + self._push_where("AND", column, operator, value) + return self - if len(args) == 2: - column, value = args - operator = "=" - elif len(args) == 3: - column, operator, value = args - else: + def or_where(self, *args: Any) -> QueryBuilder: + column, operator, value = _unpack_condition(args) + self._push_where("OR", column, operator, value) + return self + + def where_in(self, column: str, values: list[Any] | tuple[Any, ...]) -> QueryBuilder: + self._push_where("AND", column, "IN", values) + return self + + def where_not_in( + self, column: str, values: list[Any] | tuple[Any, ...] + ) -> QueryBuilder: + self._push_where("AND", column, "NOT IN", values) + return self + + def where_null(self, column: str) -> QueryBuilder: + self._push_where("AND", column, "IS", None) + return self + + def where_not_null(self, column: str) -> QueryBuilder: + self._push_where("AND", column, "IS NOT", None) + return self + + def where_like(self, column: str, pattern: str) -> QueryBuilder: + self._push_where("AND", column, "LIKE", pattern) + return self + + def where_json( + self, + column: str, + path: str, + value: Any, + operator: str = "=", + ) -> QueryBuilder: + """Filtra por un campo dentro de una columna JSON. + + ``path`` se valida contra ``$.foo.bar`` / ``$[0]``; **no** se + parametriza (es estructura, no valor) pero sí se restringe a un + alfabeto seguro. + """ + validate_identifier(column, kind="column") + if not _JSON_PATH_RE.match(path): raise error_codes.INVALID_QUERY_PARAMS.build( - f"where() acepta 2 o 3 argumentos, recibió {len(args)}." + f"path JSON inválido: {path!r}. Usa $.foo o $[0]." ) + op = validate_operator(operator) + col_sql = self.dialect.quote_identifier(column) + sql = ( + f"JSON_UNQUOTE(JSON_EXTRACT({col_sql}, '{path}')) " + f"{op} {self.dialect.placeholder}" + ) + self._where.append(("AND", sql, [value])) + return self + def where_between(self, column: str, low: Any, high: Any) -> QueryBuilder: validate_identifier(column, kind="column") - op = validate_operator(operator) - self._where.append((column, op, value)) + col_sql = self.dialect.quote_identifier(column) + ph = self.dialect.placeholder + self._where.append( + ("AND", f"{col_sql} BETWEEN {ph} AND {ph}", [low, high]) + ) + return self + + def where_group(self, callback: Callable[[QueryBuilder], None]) -> QueryBuilder: + """Agrupa condiciones entre paréntesis. Útil para mezclar AND/OR. + + .. code-block:: python + + q.where("active", True).where_group( + lambda g: g.where("role", "admin").or_where("role", "owner") + ) + """ + sub = QueryBuilder(self.db, self.table_name, dialect=self.dialect) + callback(sub) + if not sub._where: + return self + group_sql, group_params = _compile_where_clause(sub._where, leading=False) + self._where.append(("AND", f"({group_sql})", group_params)) return self def _where_many(self, conditions: list[Any]) -> QueryBuilder: for cond in conditions: - if not isinstance(cond, (list, tuple)): + if not isinstance(cond, list | tuple): raise error_codes.INVALID_QUERY_PARAMS.build( - f"cada condición de where() debe ser list/tuple, recibió {type(cond).__name__}." - ) - if len(cond) == 2: - self.where(cond[0], cond[1]) - elif len(cond) == 3: - self.where(cond[0], cond[1], cond[2]) - else: - raise error_codes.INVALID_QUERY_PARAMS.build( - f"condición con {len(cond)} elementos; se esperaban 2 o 3." + f"cada condición debe ser list/tuple, recibió {type(cond).__name__}." ) + column, operator, value = _unpack_condition(tuple(cond)) + self._push_where("AND", column, operator, value) return self # ------------------------------------------------------------------ - # ORDER / LIMIT + # GROUP BY / HAVING / ORDER / LIMIT # ------------------------------------------------------------------ + def group_by(self, *columns: str) -> QueryBuilder: + if not columns: + raise error_codes.MISSING_REQUIRED_DATA.build( + "group_by() requiere al menos una columna." + ) + for col in columns: + validate_identifier(col, kind="column") + self._group_by.extend(columns) + return self + + def having(self, *args: Any) -> QueryBuilder: + column, operator, value = _unpack_condition(args) + sql, params = self._where_fragment(column, operator, value) + self._having.append(("AND", sql, params)) + return self + def order_by(self, column: str, direction: str = "ASC") -> QueryBuilder: validate_identifier(column, kind="column") d = direction.strip().upper() @@ -160,71 +281,178 @@ def offset(self, n: int) -> QueryBuilder: return self # ------------------------------------------------------------------ - # Compilación de WHERE + # Compilación # ------------------------------------------------------------------ def _compile_where(self) -> tuple[str, list[Any]]: if not self._where: return "", [] - parts: list[str] = [] - params: list[Any] = [] - for column, op, value in self._where: - col_sql = self.dialect.quote_identifier(column) - if op in {"IN", "NOT IN"}: - if not isinstance(value, (list, tuple)) or not value: - raise error_codes.INVALID_QUERY_PARAMS.build( - f"{op} requiere lista/tupla no vacía." - ) - placeholders = ", ".join([self.dialect.placeholder] * len(value)) - parts.append(f"{col_sql} {op} ({placeholders})") - params.extend(value) - elif op in {"IS", "IS NOT"} and value is None: - parts.append(f"{col_sql} {op} NULL") - else: - parts.append(f"{col_sql} {op} {self.dialect.placeholder}") - params.append(value) - return "WHERE " + " AND ".join(parts), params - - # ------------------------------------------------------------------ - # SELECT execution - # ------------------------------------------------------------------ - - def get(self) -> list[dict[str, Any]]: - select_clause = "*" - if self._selected: - select_clause = ", ".join(self.dialect.quote_identifier(c) for c in self._selected) - - joins = " ".join(self._joins) + sql, params = _compile_where_clause(self._where, leading=True) + return sql, params + + def _compile_having(self) -> str: + if not self._having: + return "" + sql, _ = _compile_where_clause(self._having, leading=False) + return "HAVING " + sql + + def _compile_select_clause(self) -> str: + if not self._selected: + return "*" + return ", ".join(self.dialect.quote_identifier(c) for c in self._selected) + + def _compile_tail(self) -> tuple[str, list[Any]]: + """Devuelve la SQL común WHERE+GROUP+HAVING+ORDER+LIMIT y sus params.""" where_sql, params = self._compile_where() - + having_sql = self._compile_having() + for _, _, p in self._having: + params.extend(p) + + group_sql = "" + if self._group_by: + group_sql = "GROUP BY " + ", ".join( + self.dialect.quote_identifier(c) for c in self._group_by + ) order_sql = "" if self._order_by: order_sql = "ORDER BY " + ", ".join( f"{self.dialect.quote_identifier(c)} {d}" for c, d in self._order_by ) limit_sql = self.dialect.render_limit(self._limit, self._offset) + tail = " ".join(p for p in [where_sql, group_sql, having_sql, order_sql, limit_sql] if p) + return tail, params + + # ------------------------------------------------------------------ + # Lectura + # ------------------------------------------------------------------ + def get(self) -> list[dict[str, Any]]: + select_clause = self._compile_select_clause() + joins = " ".join(self._joins) + tail, params = self._compile_tail() table = self.dialect.quote_identifier(self.table_name) - parts = [f"SELECT {select_clause} FROM {table}", joins, where_sql, order_sql, limit_sql] - query = " ".join(p for p in parts if p).strip() - return self.db.execute(query, tuple(params) if params else None) + query = " ".join( + p for p in [f"SELECT {select_clause} FROM {table}", joins, tail] if p + ) + return self.db.execute(query.strip(), tuple(params) if params else None) def first(self) -> dict[str, Any] | None: self._limit = 1 rows = self.get() return rows[0] if rows else None - def count(self, column: str = "*") -> int: + def find(self, pk_value: Any, *, pk: str = "id") -> dict[str, Any] | None: + return self.where(pk, pk_value).first() + + def exists(self) -> bool: + return self.count() > 0 + + def pluck(self, column: str) -> list[Any]: + validate_identifier(column, kind="column") + rows = self.select(column).get() + return [row[column] for row in rows] + + def _aggregate(self, fn: str, column: str) -> Any: col = "*" if column == "*" else self.dialect.quote_identifier(column) joins = " ".join(self._joins) - where_sql, params = self._compile_where() + tail, params = self._compile_tail() table = self.dialect.quote_identifier(self.table_name) - query = f"SELECT COUNT({col}) AS cnt FROM {table} {joins} {where_sql}".strip() - rows = self.db.execute(query, tuple(params) if params else None) - return int(rows[0]["cnt"]) if rows else 0 + query = " ".join( + p for p in [f"SELECT {fn}({col}) AS v FROM {table}", joins, tail] if p + ) + rows = self.db.execute(query.strip(), tuple(params) if params else None) + return rows[0]["v"] if rows else None + + def count(self, column: str = "*") -> int: + result = self._aggregate("COUNT", column) + return int(result) if result is not None else 0 + + def sum(self, column: str) -> Any: + return self._aggregate("SUM", column) + + def avg(self, column: str) -> Any: + return self._aggregate("AVG", column) + + def min(self, column: str) -> Any: + return self._aggregate("MIN", column) + + def max(self, column: str) -> Any: + return self._aggregate("MAX", column) # ------------------------------------------------------------------ - # INSERT / UPDATE / DELETE + # Paginación y streaming + # ------------------------------------------------------------------ + + def paginate(self, page: int = 1, per_page: int = 25) -> dict[str, Any]: + """Devuelve ``{page, per_page, total, last_page, data}``.""" + if page < 1 or per_page < 1: + raise error_codes.INVALID_QUERY_PARAMS.build( + "paginate() exige page>=1 y per_page>=1." + ) + total = QueryBuilder._clone_for_count(self).count() + last_page = max(1, (total + per_page - 1) // per_page) + self._limit = per_page + self._offset = (page - 1) * per_page + return { + "page": page, + "per_page": per_page, + "total": total, + "last_page": last_page, + "data": self.get(), + } + + @staticmethod + def _clone_for_count(src: QueryBuilder) -> QueryBuilder: + """Clon ligero que comparte WHERE/JOIN pero sin limit/offset/order.""" + clone = QueryBuilder(src.db, src.table_name, dialect=src.dialect) + clone._joins = list(src._joins) + clone._where = list(src._where) + clone._group_by = list(src._group_by) + clone._having = list(src._having) + return clone + + def chunk( + self, + size: int, + callback: Callable[[list[dict[str, Any]]], None], + *, + order_by_pk: str = "id", + ) -> None: + """Procesa la consulta en lotes de ``size`` filas. + + Pagina por ``OFFSET`` (suficiente para tablas medianas). Para + tablas muy grandes usar :meth:`iterate` con cursor keyset. + """ + if size < 1: + raise error_codes.INVALID_QUERY_PARAMS.build("chunk size debe ser >= 1.") + offset = 0 + while True: + clone = QueryBuilder._clone_for_count(self) + clone._order_by = list(self._order_by) or [(order_by_pk, "ASC")] + clone._limit = size + clone._offset = offset + batch = clone.get() + if not batch: + return + callback(batch) + if len(batch) < size: + return + offset += size + + def iterate( + self, + chunk_size: int = 1000, + *, + order_by_pk: str = "id", + ) -> Iterator[dict[str, Any]]: + """Generator que recorre toda la consulta por lotes.""" + batches: list[list[dict[str, Any]]] = [] + self.chunk(chunk_size, batches.append, order_by_pk=order_by_pk) + for batch in batches: + yield from batch + + # ------------------------------------------------------------------ + # INSERT / UPDATE / DELETE / UPSERT # ------------------------------------------------------------------ def insert(self, data: dict[str, Any] | list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -253,12 +481,12 @@ def _insert_many(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: if not first_keys: raise error_codes.MISSING_REQUIRED_DATA.build("insert(): filas vacías.") cols = [validate_identifier(k, kind="column") for k in first_keys] - # Forzamos que todas las filas tengan las mismas claves y en el mismo orden. values: list[tuple[Any, ...]] = [] for row in rows: if list(row.keys()) != first_keys: raise error_codes.INVALID_DATA_FORMAT.build( - "insert(): todas las filas deben tener las mismas claves en el mismo orden." + "insert(): todas las filas deben tener las mismas claves " + "en el mismo orden." ) values.append(tuple(row[k] for k in first_keys)) cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) @@ -267,6 +495,37 @@ def _insert_many(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: query = f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders})" return self.db.execute(query, values, many=True) + def upsert( + self, + data: dict[str, Any], + *, + update: list[str] | None = None, + on: list[str] | None = None, + ) -> list[dict[str, Any]]: + """INSERT con resolución de conflicto. + + :param data: columnas → valores. + :param update: columnas a pisar en conflicto (default todas). + :param on: columnas del conflicto. Requerido por Postgres, + opcional en MySQL (lo detecta por la PK). + """ + if not data: + raise error_codes.MISSING_REQUIRED_DATA.build("upsert(): dict vacío.") + cols = [validate_identifier(k, kind="column") for k in data] + cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) + placeholders = ", ".join([self.dialect.placeholder] * len(cols)) + update_cols = update if update is not None else list(data.keys()) + for col in update_cols: + validate_identifier(col, kind="column") + for col in on or []: + validate_identifier(col, kind="column") + update_sql = self.dialect.compile_upsert_update(update_cols, on) + table = self.dialect.quote_identifier(self.table_name) + query = ( + f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders}) {update_sql}" + ) + return self.db.execute(query, tuple(data.values())) + def update(self, data: dict[str, Any]) -> list[dict[str, Any]]: if not data: raise error_codes.MISSING_REQUIRED_DATA.build("update() requiere datos.") @@ -274,7 +533,9 @@ def update(self, data: dict[str, Any]) -> list[dict[str, Any]]: params: list[Any] = [] for col, val in data.items(): validate_identifier(col, kind="column") - set_parts.append(f"{self.dialect.quote_identifier(col)} = {self.dialect.placeholder}") + set_parts.append( + f"{self.dialect.quote_identifier(col)} = {self.dialect.placeholder}" + ) params.append(val) where_sql, where_params = self._compile_where() params.extend(where_params) @@ -291,3 +552,37 @@ def delete(self) -> list[dict[str, Any]]: table = self.dialect.quote_identifier(self.table_name) query = f"DELETE FROM {table} {where_sql}".strip() return self.db.execute(query, tuple(params)) + + def truncate(self) -> list[dict[str, Any]]: + """``TRUNCATE TABLE`` — borra todas las filas y resetea AUTO_INCREMENT.""" + table = self.dialect.quote_identifier(self.table_name) + return self.db.execute(f"TRUNCATE TABLE {table}") + + +# --------------------------------------------------------------------------- +# Helpers de módulo +# --------------------------------------------------------------------------- + +def _unpack_condition(args: tuple[Any, ...]) -> tuple[str, str, Any]: + if len(args) == 2: + return args[0], "=", args[1] + if len(args) == 3: + return args[0], args[1], args[2] + raise error_codes.INVALID_QUERY_PARAMS.build( + f"se esperaban 2 o 3 argumentos, llegaron {len(args)}." + ) + + +def _compile_where_clause( + items: list[_WhereItem], *, leading: bool +) -> tuple[str, list[Any]]: + parts: list[str] = [] + params: list[Any] = [] + for idx, (connector, sql, p) in enumerate(items): + if idx == 0: + parts.append(sql) + else: + parts.append(f"{connector} {sql}") + params.extend(p) + body = " ".join(parts) + return (f"WHERE {body}", params) if leading else (body, params) diff --git a/shiba/core/table_builder.py b/shiba/core/table_builder.py index b2ec6d3..442ad78 100644 --- a/shiba/core/table_builder.py +++ b/shiba/core/table_builder.py @@ -68,7 +68,7 @@ def default(self, value: str | int | float | bool | None) -> TableBuilder: return self._amend_last("DEFAULT NULL") if isinstance(value, bool): return self._amend_last(f"DEFAULT {1 if value else 0}") - if isinstance(value, (int, float)): + if isinstance(value, int | float): return self._amend_last(f"DEFAULT {value}") return self._amend_last("DEFAULT " + _escape_enum_choice(value)) @@ -107,10 +107,11 @@ def _col(self, column_name: str, sql_type: str) -> TableBuilder: def increments(self, column_name: str = "id", primary_key: bool = False) -> TableBuilder: validate_identifier(column_name, kind="column") - suffix = " PRIMARY KEY" if primary_key else "" - self._append_column( - f"{self.dialect.quote_identifier(column_name)} INT AUTO_INCREMENT{suffix}" - ) + col_quoted = self.dialect.quote_identifier(column_name) + if primary_key: + self._append_column(self.dialect.compile_auto_increment_pk(col_quoted)) + else: + self._append_column(f"{col_quoted} INT AUTO_INCREMENT") return self def integer(self, column_name: str, length: int | None = None) -> TableBuilder: diff --git a/shiba/dialects/base.py b/shiba/dialects/base.py index 0cf5679..eca1ffa 100644 --- a/shiba/dialects/base.py +++ b/shiba/dialects/base.py @@ -41,3 +41,25 @@ def render_limit(self, limit: int | None, offset: int | None) -> str: if offset is not None: parts.append(f"OFFSET {int(offset)}") return " ".join(parts) + + @abstractmethod + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: + """Cláusula de resolución de conflicto para ``upsert``. + + MySQL → ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` + (``conflict_columns`` se ignora; lo detecta por la PK). + Postgres/SQLite → ``ON CONFLICT (col, ...) DO UPDATE SET col = EXCLUDED.col`` + (``conflict_columns`` obligatorio). + """ + + def compile_auto_increment_pk(self, column_quoted: str) -> str: + """Declaración inline de PK auto-incremental. + + Default MySQL-ish. Postgres lo override con + ``GENERATED ALWAYS AS IDENTITY``. + """ + return f"{column_quoted} INT AUTO_INCREMENT PRIMARY KEY" diff --git a/shiba/dialects/mysql/dialect.py b/shiba/dialects/mysql/dialect.py index 2b7685d..595f40e 100644 --- a/shiba/dialects/mysql/dialect.py +++ b/shiba/dialects/mysql/dialect.py @@ -17,3 +17,13 @@ def quote_identifier(self, name: str) -> str: def map_type(self, declared: str) -> str: return _map_type(declared) + + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: + if not update_columns: + return "" + parts = [f"{_qi(c)} = VALUES({_qi(c)})" for c in update_columns] + return "ON DUPLICATE KEY UPDATE " + ", ".join(parts) diff --git a/shiba/dialects/mysql/driver.py b/shiba/dialects/mysql/driver.py index 8127038..62b83fc 100644 --- a/shiba/dialects/mysql/driver.py +++ b/shiba/dialects/mysql/driver.py @@ -192,6 +192,20 @@ def execute( # Alias retro-compatible con la API v1.x. execute_query = execute + def raw( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + """Escape hatch para SQL crudo — sin builder. + + El llamador es responsable de pasar **valores siempre como + parámetros**, nunca interpolados en ``query``. + """ + return self.execute(query, params, many=many) + def _rollback_silent(self) -> None: if self._connection is None or not self._connection.open: return diff --git a/shiba/dialects/postgres/__init__.py b/shiba/dialects/postgres/__init__.py new file mode 100644 index 0000000..8507025 --- /dev/null +++ b/shiba/dialects/postgres/__init__.py @@ -0,0 +1,19 @@ +"""Dialecto PostgreSQL.""" +from typing import Any + +from shiba.dialects.postgres.dialect import PostgresDialect + +__all__ = ["PostgresDialect"] + + +def _import_driver() -> Any: + """Import perezoso de ``Database`` (requiere ``psycopg``).""" + from shiba.dialects.postgres.driver import Database + + return Database + + +def __getattr__(name: str) -> Any: + if name == "Database": + return _import_driver() + raise AttributeError(name) diff --git a/shiba/dialects/postgres/dialect.py b/shiba/dialects/postgres/dialect.py new file mode 100644 index 0000000..460fa21 --- /dev/null +++ b/shiba/dialects/postgres/dialect.py @@ -0,0 +1,39 @@ +"""Implementación :class:`~shiba.dialects.base.Dialect` para PostgreSQL.""" +from __future__ import annotations + +from shiba import error_codes +from shiba.dialects.base import Dialect +from shiba.dialects.postgres.quoting import quote_identifier as _qi +from shiba.dialects.postgres.schema import map_type as _map_type + + +class PostgresDialect(Dialect): + """Doble comillas, placeholder ``%s`` (psycopg), ``ON CONFLICT``.""" + + name = "postgres" + placeholder = "%s" + + def quote_identifier(self, name: str) -> str: + return _qi(name) + + def map_type(self, declared: str) -> str: + return _map_type(declared) + + def compile_auto_increment_pk(self, column_quoted: str) -> str: + return f"{column_quoted} INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY" + + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: + if not conflict_columns: + error_codes.MISSING_REQUIRED_DATA.raise_( + "upsert() en Postgres requiere el parámetro `on=[col, ...]` " + "para identificar el conflicto." + ) + target = ", ".join(_qi(c) for c in conflict_columns) + if not update_columns: + return f"ON CONFLICT ({target}) DO NOTHING" + sets = ", ".join(f"{_qi(c)} = EXCLUDED.{_qi(c)}" for c in update_columns) + return f"ON CONFLICT ({target}) DO UPDATE SET {sets}" diff --git a/shiba/dialects/postgres/driver.py b/shiba/dialects/postgres/driver.py new file mode 100644 index 0000000..c710106 --- /dev/null +++ b/shiba/dialects/postgres/driver.py @@ -0,0 +1,238 @@ +"""Wrapper sobre :mod:`psycopg` (v3) con la interfaz Shiba ``Database``.""" +from __future__ import annotations + +import logging +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +from shiba import error_codes +from shiba.dialects.postgres.quoting import quote_identifier +from shiba.error_codes import from_driver_exception + +if TYPE_CHECKING: + from types import TracebackType + +try: + import psycopg + from psycopg.rows import dict_row +except ImportError: # pragma: no cover - dep opcional + psycopg = None # type: ignore[assignment] + dict_row = None # type: ignore[assignment] + + +def _require_psycopg() -> None: + if psycopg is None: + raise ImportError( + "El dialecto Postgres requiere `psycopg[binary]`. " + "Instala: pip install 'shiba_mysql[postgres]' o psycopg[binary]" + ) + + +logger = logging.getLogger("shiba.postgres") + + +class Database: + """Conexión Postgres con la misma API que :class:`shiba.Database` de MySQL.""" + + def __init__( + self, + host: str, + port: int, + user: str, + password: str, + *, + database: str | None = None, + autoconnect: bool = True, + ) -> None: + _require_psycopg() + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self._connection: Any = None + self._in_transaction: bool = False + if autoconnect: + self.connect() + + # ------------------------------------------------------------------ + # Ciclo de vida + # ------------------------------------------------------------------ + + def connect(self) -> None: + if self._connection is not None and not self._connection.closed: + return + try: + self._connection = psycopg.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + dbname=self.database, + autocommit=True, + row_factory=dict_row, + ) + except psycopg.OperationalError as exc: + code = from_driver_exception(exc) + raise code.build( + f"No se pudo conectar a {self.host}:{self.port}: {exc}", + details={"host": self.host, "port": self.port}, + ) from exc + + def close(self) -> None: + if self._connection is not None and not self._connection.closed: + self._connection.close() + self._connection = None + self._in_transaction = False + + def __enter__(self) -> Database: + self.connect() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + @property + def _conn(self) -> Any: + if self._connection is None or self._connection.closed: + raise error_codes.CONNECTION_NOT_OPEN.build( + "La conexión Postgres no está abierta." + ) + return self._connection + + # ------------------------------------------------------------------ + # Ejecución + # ------------------------------------------------------------------ + + def execute( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + if not query or not isinstance(query, str): + raise error_codes.EMPTY_QUERY.build( + "Se intentó ejecutar una query vacía.", + query=str(query), + ) + conn = self._conn + try: + with conn.cursor() as cursor: + if params is None: + cursor.execute(query) + elif many: + cursor.executemany(query, params) + else: + cursor.execute(query, params) + try: + rows: list[dict[str, Any]] = list(cursor.fetchall()) + except psycopg.ProgrammingError: + rows = [] + if not self._in_transaction: + conn.commit() + return rows + except psycopg.errors.IntegrityError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Violación de integridad: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.errors.SyntaxError as exc: + self._rollback_silent() + raise error_codes.QUERY_SYNTAX_ERROR.build( + f"Error de sintaxis: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.OperationalError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error operacional: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.Error as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error de driver: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + + execute_query = execute + + def raw( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + return self.execute(query, params, many=many) + + def _rollback_silent(self) -> None: + if self._connection is None or self._connection.closed: + return + try: + self._connection.rollback() + except psycopg.Error: # pragma: no cover - best effort + logger.warning("rollback failed", exc_info=True) + + # ------------------------------------------------------------------ + # Transacciones + # ------------------------------------------------------------------ + + @contextmanager + def transaction(self) -> Iterator[Database]: + if self._in_transaction: + raise error_codes.TRANSACTION_ALREADY_ACTIVE.build() + conn = self._conn + # En psycopg3 autocommit=True implica BEGIN explícito para abrir tx. + conn.execute("BEGIN") + self._in_transaction = True + try: + yield self + except BaseException: + self._rollback_silent() + raise + else: + conn.commit() + finally: + self._in_transaction = False + + # ------------------------------------------------------------------ + # DDL conveniencia + # ------------------------------------------------------------------ + + def create_database(self, name: str) -> Database: + try: + self.execute(f"CREATE DATABASE {quote_identifier(name)}") + except Exception as exc: # pragma: no cover - mensajes varían + if "already exists" not in str(exc): + raise + self.database = name + return self + + def use_database(self, name: str) -> Database: + """En Postgres no hay ``USE``. Se cierra y reconecta a la nueva DB.""" + self.close() + self.database = name + self.connect() + return self + + selected_database = use_database diff --git a/shiba/dialects/postgres/quoting.py b/shiba/dialects/postgres/quoting.py new file mode 100644 index 0000000..960e17b --- /dev/null +++ b/shiba/dialects/postgres/quoting.py @@ -0,0 +1,13 @@ +"""Quoting de identificadores Postgres (comillas dobles).""" +from __future__ import annotations + +from shiba.identifiers import validate_identifier + + +def quote_identifier(name: str) -> str: + """Valida y cita el identificador con comillas dobles. + + Soporta ``schema.table``, citando cada parte. + """ + validate_identifier(name) + return ".".join(f'"{part}"' for part in name.split(".")) diff --git a/shiba/dialects/postgres/schema.py b/shiba/dialects/postgres/schema.py new file mode 100644 index 0000000..4762b7a --- /dev/null +++ b/shiba/dialects/postgres/schema.py @@ -0,0 +1,26 @@ +"""Mapeo de tipos canónicos Shiba a SQL Postgres. + +El ``TableBuilder`` emite sintaxis estilo MySQL (``DATETIME``, +``BOOLEAN``, ``JSON``, ``BLOB``, etc.). Este módulo los traduce a la +forma idiomática de Postgres. +""" +from __future__ import annotations + +_MAPPING: dict[str, str] = { + "DATETIME": "TIMESTAMP", + "BLOB": "BYTEA", + "JSON": "JSONB", + "DOUBLE": "DOUBLE PRECISION", +} + + +def map_type(declared: str) -> str: + upper = declared.upper() + # Tipos con parámetros: ``VARCHAR(50)``, ``DECIMAL(10,2)``, etc. + if "(" in upper: + base, params = upper.split("(", 1) + base = base.strip() + if base in _MAPPING: + return f"{_MAPPING[base]}({params}" + return declared + return _MAPPING.get(upper, declared) diff --git a/shiba/error_codes.py b/shiba/error_codes.py index 555619d..ef5198d 100644 --- a/shiba/error_codes.py +++ b/shiba/error_codes.py @@ -32,7 +32,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from shiba.errors import ( ConnectionError, @@ -64,7 +64,7 @@ def build(self, message: str | None = None, **kwargs: Any) -> ShibaError: msg = message or self.default_message return self.exception_class(msg, code=self, **kwargs) - def raise_(self, message: str | None = None, **kwargs: Any) -> None: + def raise_(self, message: str | None = None, **kwargs: Any) -> NoReturn: """Lanza la excepción asociada con este código. Para ``QueryError`` y descendientes acepta ``query=``, ``params=``, diff --git a/shiba/identifiers.py b/shiba/identifiers.py index 964b65d..3f2083d 100644 --- a/shiba/identifiers.py +++ b/shiba/identifiers.py @@ -55,8 +55,8 @@ def validate_identifier(name: str, *, kind: str = "identifier") -> str: def validate_operator(op: str) -> str: """Acepta el operador o lanza :class:`SchemaError`.""" - if not isinstance(op, str): # defensa para callers no tipados - error_codes.INVALID_OPERATOR.raise_( # type: ignore[unreachable] + if not isinstance(op, str): # defensa para callers no tipados # type: ignore[unreachable] + error_codes.INVALID_OPERATOR.raise_( f"operador no string: {op!r}", details={"value": repr(op)}, ) diff --git a/shiba/orm/__init__.py b/shiba/orm/__init__.py new file mode 100644 index 0000000..ecaf73a --- /dev/null +++ b/shiba/orm/__init__.py @@ -0,0 +1,40 @@ +"""ORM tipado para Shiba. + +Uso mínimo: + +.. code-block:: python + + import shiba + from shiba.orm import Model, fields + + class User(Model): + __table__ = "users" + + id: int = fields.PrimaryKey() + name: str + email: str = fields.String(unique=True) + age: int | None = None + + shiba.set_default_connection(cx) + + User.create_table() + user = User(name="John", email="j@x.com", age=30) + user.save() + User.find(1) + User.where("age", ">", 18).get() +""" +from shiba.orm import fields +from shiba.orm.model import ( + Model, + ModelQuery, + get_default_connection, + set_default_connection, +) + +__all__ = [ + "Model", + "ModelQuery", + "fields", + "get_default_connection", + "set_default_connection", +] diff --git a/shiba/orm/fields.py b/shiba/orm/fields.py new file mode 100644 index 0000000..b1dfb34 --- /dev/null +++ b/shiba/orm/fields.py @@ -0,0 +1,305 @@ +"""Field descriptors para modelos Shiba. + +Los modelos del ORM combinan anotaciones Python (que dan el tipo) con +instancias de :class:`Field` (que dan metadata extra: unique, default, +auto-increment, etc.). Cuando un atributo sólo tiene anotación, el +:class:`Field` se infiere automáticamente con defaults razonables. +""" +from __future__ import annotations + +import json +import types +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import date, datetime +from decimal import Decimal +from typing import Any, Union, get_args, get_origin + +from shiba import error_codes + +_UNSET: Any = object() +"""Sentinel para distinguir 'no se pasó valor' de 'None explícito'.""" + + +@dataclass +class Field: + """Descriptor de columna SQL. + + Casi nunca se instancia directamente — se usan las subclases + semánticas (:class:`String`, :class:`Integer`, etc.) y la inferencia + desde anotaciones. + """ + + sql_type: str = "VARCHAR(255)" + nullable: bool = False + primary_key: bool = False + unique: bool = False + auto_increment: bool = False + default: Any = _UNSET + default_factory: Callable[[], Any] | None = None + column_name: str | None = None + foreign_key: tuple[str, str] | None = None + json: bool = False + enum_choices: tuple[str, ...] | None = None + indexed: bool = False + + # --- Conversión Python <-> DB -------------------------------------- + + def to_python(self, raw: Any) -> Any: + """Decodifica el valor que llega desde la fila SQL.""" + if raw is None: + return None + if self.json and isinstance(raw, str | bytes): + return json.loads(raw) + return raw + + def to_db(self, value: Any) -> Any: + """Codifica el valor antes de enviarlo a la base de datos.""" + if value is None: + return None + if self.json and not isinstance(value, str | bytes): + return json.dumps(value, default=str) + return value + + # --- Defaults ------------------------------------------------------- + + def has_default(self) -> bool: + return self.default is not _UNSET or self.default_factory is not None + + def get_default(self) -> Any: + if self.default_factory is not None: + return self.default_factory() + if self.default is _UNSET: + return None + return self.default + + # --- DDL ------------------------------------------------------------ + + def apply_to_table_builder(self, tb: Any, column_name: str) -> None: + """Replica este campo en el ``TableBuilder``.""" + col = self.column_name or column_name + # Tipo + if self.primary_key and self.auto_increment and self.sql_type == "INT": + tb.increments(col, primary_key=True) + elif self.enum_choices is not None: + tb.enum(col, list(self.enum_choices)) + else: + # Reusar map de tipos via raw type string. + type_lower = self.sql_type.lower() + if type_lower.startswith("varchar"): + length = int(self.sql_type[8:-1]) if "(" in self.sql_type else 255 + tb.string(col, length) + elif type_lower == "text": + tb.text(col) + elif type_lower.startswith("int"): + tb.integer(col) + elif type_lower == "bigint": + tb.big_integer(col) + elif type_lower == "smallint": + tb.small_integer(col) + elif type_lower == "tinyint": + tb.tiny_integer(col) + elif type_lower == "boolean": + tb.boolean(col) + elif type_lower == "json": + tb.json(col) + elif type_lower == "datetime": + tb.datetime(col) + elif type_lower == "date": + tb.date(col) + elif type_lower == "time": + tb.time(col) + elif type_lower == "timestamp": + tb.timestamp(col) + elif type_lower.startswith("decimal"): + tb.decimal(col) + elif type_lower.startswith("float") or type_lower == "double": + tb.floats(col) + elif type_lower == "blob": + tb.binary(col) + else: + error_codes.UNSUPPORTED_TYPE.raise_( + f"tipo {self.sql_type!r} no soportado por apply_to_table_builder." + ) + # Constraints adicionales — añadidas al último column declared. + if self.primary_key and not self.auto_increment: + tb.primary() + if self.unique: + tb.unique() + if self.nullable: + tb.nullable() + elif not self.primary_key: + tb.not_nullable() + if self.default is not _UNSET and self.default_factory is None: + tb.default(self.default) + if self.foreign_key is not None: + ftable, fcolumn = self.foreign_key + tb.foreign(f"fk_{col}_{ftable}", ftable, fcolumn) + + +# --------------------------------------------------------------------------- +# Subclases semánticas +# --------------------------------------------------------------------------- + + +@dataclass +class PrimaryKey(Field): + sql_type: str = "INT" + primary_key: bool = True + auto_increment: bool = True + + +@dataclass +class String(Field): + max_length: int = 255 + + def __post_init__(self) -> None: + self.sql_type = f"VARCHAR({self.max_length})" + + +@dataclass +class Text(Field): + sql_type: str = "TEXT" + + +@dataclass +class Integer(Field): + sql_type: str = "INT" + + +@dataclass +class BigInteger(Field): + sql_type: str = "BIGINT" + + +@dataclass +class Boolean(Field): + sql_type: str = "BOOLEAN" + + def to_python(self, raw: Any) -> Any: + if raw is None: + return None + return bool(raw) + + +@dataclass +class FloatField(Field): + sql_type: str = "FLOAT" + + +@dataclass +class DecimalField(Field): + sql_type: str = "DECIMAL" + + def to_python(self, raw: Any) -> Any: + if raw is None: + return None + return Decimal(str(raw)) + + +@dataclass +class DateTime(Field): + sql_type: str = "DATETIME" + default_now: bool = False + + def __post_init__(self) -> None: + if self.default_now and self.default_factory is None: + self.default_factory = datetime.now + + def to_python(self, raw: Any) -> Any: + if raw is None or isinstance(raw, datetime): + return raw + return datetime.fromisoformat(str(raw)) + + +@dataclass +class DateField(Field): + sql_type: str = "DATE" + + def to_python(self, raw: Any) -> Any: + if raw is None or isinstance(raw, date): + return raw + return date.fromisoformat(str(raw)) + + +@dataclass +class Json(Field): + sql_type: str = "JSON" + json: bool = True + + +@dataclass +class Enum(Field): + choices: tuple[str, ...] = field(default_factory=tuple) + + def __post_init__(self) -> None: + if not self.choices: + error_codes.MISSING_REQUIRED_DATA.raise_( + "Enum requiere choices." + ) + self.sql_type = "ENUM" + self.enum_choices = self.choices + + +@dataclass +class ForeignKey(Field): + """``ForeignKey(to='users', column='id')``.""" + + to: str = "" + column: str = "id" + sql_type: str = "INT" + + def __post_init__(self) -> None: + if not self.to: + error_codes.MISSING_REQUIRED_DATA.raise_( + "ForeignKey requiere `to` (tabla destino)." + ) + self.foreign_key = (self.to, self.column) + + +# --------------------------------------------------------------------------- +# Inferencia desde anotaciones +# --------------------------------------------------------------------------- + + +def _is_optional(hint: Any) -> tuple[bool, Any]: + """Devuelve ``(nullable, inner_type)`` para ``X | None``.""" + origin = get_origin(hint) + if origin in (Union, types.UnionType): + args = [a for a in get_args(hint) if a is not type(None)] + if len(get_args(hint)) > len(args): + return True, args[0] if len(args) == 1 else hint + return False, hint + + +def infer_field(hint: Any, default: Any = _UNSET) -> Field: + """Construye un :class:`Field` a partir de la anotación del atributo.""" + nullable, inner = _is_optional(hint) + has_default = default is not _UNSET + f: Field + + if inner is str: + f = String() + elif inner is int: + f = Integer() + elif inner is bool: + f = Boolean() + elif inner is float: + f = FloatField() + elif inner is Decimal: + f = DecimalField() + elif inner is datetime: + f = DateTime() + elif inner is date: + f = DateField() + elif inner is bytes: + f = Field(sql_type="BLOB") + elif inner is dict or get_origin(inner) is dict or inner is list or get_origin(inner) is list: + f = Json() + else: + f = Field() + + f.nullable = nullable or (has_default and default is None) + if has_default: + f.default = default + return f diff --git a/shiba/orm/model.py b/shiba/orm/model.py new file mode 100644 index 0000000..f0be95f --- /dev/null +++ b/shiba/orm/model.py @@ -0,0 +1,314 @@ +"""Modelos POO con metaclass que lee anotaciones.""" +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar + +from shiba import error_codes +from shiba.orm.fields import _UNSET, Field, infer_field + +if TYPE_CHECKING: + from shiba import ShibaConnection + from shiba.core.query_builder import QueryBuilder + + +T = TypeVar("T", bound="Model") + + +_default_connection: ShibaConnection | None = None + + +def set_default_connection(connection: ShibaConnection) -> None: + """Registra la conexión global usada por modelos sin ``__db__``.""" + global _default_connection + _default_connection = connection + + +def get_default_connection() -> ShibaConnection | None: + return _default_connection + + +class ModelMeta(type): + """Metaclass que extrae ``_fields`` desde las anotaciones de la clase.""" + + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + ns: dict[str, Any], + ) -> ModelMeta: + cls = super().__new__(mcs, name, bases, ns) + + # No procesamos la propia clase base ``Model``. + if ns.get("__shiba_model_root__", False): + return cls + + # Recolectamos anotaciones por clase del MRO, evaluando strings + # con eval_str. Esto evita fallar si una clase ancestra tiene + # forward refs no resolvibles (p. ej. ``ShibaConnection``). + annotations: dict[str, Any] = {} + for klass in reversed(cls.__mro__): + if klass is object: + continue + try: + hints = inspect.get_annotations(klass, eval_str=True) + except (NameError, AttributeError): + hints = inspect.get_annotations(klass, eval_str=False) + annotations.update(hints) + + fields: dict[str, Field] = {} + for attr, hint in annotations.items(): + if attr.startswith("_") or attr in {"ClassVar"}: + continue + value = ns.get(attr, _UNSET) + fld = value if isinstance(value, Field) else infer_field(hint, default=value) + fields[attr] = fld + + # Quitamos el Field del namespace para que la lookup pase por + # la instancia y no devuelva el descriptor. + if attr in cls.__dict__ and isinstance(cls.__dict__[attr], Field): + delattr(cls, attr) + + cls._fields = fields # type: ignore[attr-defined] + cls._table = ns.get("__table__", name.lower()) # type: ignore[attr-defined] + return cls + + +class Model(metaclass=ModelMeta): + """Base de cualquier modelo. Se levanta como objeto Python plano.""" + + __shiba_model_root__: ClassVar[bool] = True + _fields: ClassVar[dict[str, Field]] = {} + _table: ClassVar[str] = "" + __db__: ClassVar[Any] = None # ShibaConnection | None — tipado en docstring + + # ------------------------------------------------------------------ + # Construcción + # ------------------------------------------------------------------ + + def __init__(self, **kwargs: Any) -> None: + unknown = set(kwargs) - set(self._fields) + if unknown: + error_codes.INVALID_DATA_FORMAT.raise_( + f"{type(self).__name__}: claves desconocidas {sorted(unknown)}", + details={"unknown": sorted(unknown)}, + ) + for attr, fld in self._fields.items(): + if attr in kwargs: + setattr(self, attr, kwargs[attr]) + elif fld.has_default(): + setattr(self, attr, fld.get_default()) + else: + setattr(self, attr, None) + + def __repr__(self) -> str: + pairs = ", ".join(f"{k}={getattr(self, k, None)!r}" for k in self._fields) + return f"{type(self).__name__}({pairs})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Model) or type(self) is not type(other): + return NotImplemented + return bool(self.to_dict() == other.to_dict()) + + def __hash__(self) -> int: # pragma: no cover - identidad por pk + pk_attr = self._pk_attr() + return hash((type(self).__name__, getattr(self, pk_attr, None))) + + # ------------------------------------------------------------------ + # Introspección + # ------------------------------------------------------------------ + + @classmethod + def _pk_attr(cls) -> str: + for attr, fld in cls._fields.items(): + if fld.primary_key: + return attr + error_codes.MISSING_REQUIRED_DATA.raise_( + f"{cls.__name__}: ningún campo marcado como primary_key." + ) + + def to_dict(self) -> dict[str, Any]: + return {attr: getattr(self, attr, None) for attr in self._fields} + + def to_db_dict(self, *, exclude_pk_if_none: bool = True) -> dict[str, Any]: + """Devuelve el dict listo para INSERT/UPDATE.""" + pk_attr = self._pk_attr() + out: dict[str, Any] = {} + for attr, fld in self._fields.items(): + value = getattr(self, attr, None) + if ( + attr == pk_attr + and exclude_pk_if_none + and (value is None or value == _UNSET) + ): + continue + col = fld.column_name or attr + out[col] = fld.to_db(value) + return out + + @classmethod + def from_row(cls: type[T], row: dict[str, Any]) -> T: + """Hidrata una instancia desde una fila ``dict``.""" + instance = cls.__new__(cls) + for attr, fld in cls._fields.items(): + col = fld.column_name or attr + raw = row.get(col) + setattr(instance, attr, fld.to_python(raw)) + return instance + + # ------------------------------------------------------------------ + # Conexión + # ------------------------------------------------------------------ + + @classmethod + def _connection(cls) -> ShibaConnection: + conn: ShibaConnection | None = cls.__db__ or _default_connection + if conn is None: + error_codes.CONNECTION_NOT_OPEN.raise_( + f"{cls.__name__} no tiene conexión. Llama a " + "shiba.set_default_connection(cx) o asigna `__db__`." + ) + return conn + + # ------------------------------------------------------------------ + # Query API a nivel de clase + # ------------------------------------------------------------------ + + @classmethod + def query(cls: type[T]) -> ModelQuery[T]: + return ModelQuery(cls) + + @classmethod + def all(cls: type[T]) -> list[T]: + return cls.query().get() + + @classmethod + def find(cls: type[T], pk_value: Any) -> T | None: + row = cls._connection().table(cls._table).find(pk_value, pk=cls._pk_attr()) + return cls.from_row(row) if row else None + + @classmethod + def where(cls: type[T], *args: Any) -> ModelQuery[T]: + return cls.query().where(*args) + + @classmethod + def first(cls: type[T]) -> T | None: + return cls.query().first() + + @classmethod + def count(cls) -> int: + return cls._connection().table(cls._table).count() + + # ------------------------------------------------------------------ + # Schema + # ------------------------------------------------------------------ + + @classmethod + def create_table(cls) -> None: + tb = cls._connection().create_table(cls._table) + for attr, fld in cls._fields.items(): + fld.apply_to_table_builder(tb, attr) + tb.build() + + @classmethod + def drop_table(cls) -> None: + cx = cls._connection() + cx.raw(f"DROP TABLE IF EXISTS {cx.dialect.quote_identifier(cls._table)}") + + @classmethod + def truncate_table(cls) -> None: + cls._connection().table(cls._table).truncate() + + # ------------------------------------------------------------------ + # Persistencia + # ------------------------------------------------------------------ + + def save(self: T) -> T: + pk_attr = self._pk_attr() + pk_val = getattr(self, pk_attr, None) + cx = self._connection() + data = self.to_db_dict(exclude_pk_if_none=True) + if pk_val is None: + cx.table(self._table).insert(data) + new_id = cx.raw("SELECT LAST_INSERT_ID() AS v") + if new_id and new_id[0].get("v"): + setattr(self, pk_attr, new_id[0]["v"]) + else: + cx.table(self._table).where(pk_attr, pk_val).update(data) + return self + + def delete(self) -> None: + pk_attr = self._pk_attr() + pk_val = getattr(self, pk_attr, None) + if pk_val is None: + error_codes.MISSING_REQUIRED_DATA.raise_( + f"delete(): {type(self).__name__} sin PK." + ) + self._connection().table(self._table).where(pk_attr, pk_val).delete() + + +# --------------------------------------------------------------------------- +# Manager hidratante +# --------------------------------------------------------------------------- + + +class ModelQuery(Generic[T]): + """Wrapper de :class:`QueryBuilder` que devuelve modelos en vez de dicts.""" + + def __init__(self, model_cls: type[T]) -> None: + self.model_cls = model_cls + cx = model_cls._connection() + self._qb: QueryBuilder = cx.table(model_cls._table) + + # Delegación al builder con retorno fluido. + def where(self, *args: Any) -> ModelQuery[T]: + self._qb.where(*args) + return self + + def or_where(self, *args: Any) -> ModelQuery[T]: + self._qb.or_where(*args) + return self + + def where_in(self, column: str, values: list[Any]) -> ModelQuery[T]: + self._qb.where_in(column, values) + return self + + def where_null(self, column: str) -> ModelQuery[T]: + self._qb.where_null(column) + return self + + def where_not_null(self, column: str) -> ModelQuery[T]: + self._qb.where_not_null(column) + return self + + def order_by(self, column: str, direction: str = "ASC") -> ModelQuery[T]: + self._qb.order_by(column, direction) + return self + + def limit(self, n: int) -> ModelQuery[T]: + self._qb.limit(n) + return self + + def offset(self, n: int) -> ModelQuery[T]: + self._qb.offset(n) + return self + + # Ejecución hidratada. + def get(self) -> list[T]: + rows = self._qb.get() + return [self.model_cls.from_row(r) for r in rows] + + def first(self) -> T | None: + row = self._qb.first() + return self.model_cls.from_row(row) if row else None + + def count(self) -> int: + return self._qb.count() + + def exists(self) -> bool: + return self._qb.exists() + + def paginate(self, page: int = 1, per_page: int = 25) -> dict[str, Any]: + result = self._qb.paginate(page, per_page) + result["data"] = [self.model_cls.from_row(r) for r in result["data"]] + return result diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..6ab329e --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,67 @@ +"""Fixtures que levantan MySQL y Postgres reales con Testcontainers. + +Cada test marcado con ``@pytest.mark.integration`` consume una de estas +conexiones reales. Los contenedores se reutilizan en toda la sesión. +""" +from __future__ import annotations + +import os +from collections.abc import Iterator + +import pytest + +import shiba + +try: + from testcontainers.mysql import MySqlContainer + from testcontainers.postgres import PostgresContainer +except ImportError: + MySqlContainer = None # type: ignore[assignment,misc] + PostgresContainer = None # type: ignore[assignment,misc] + + +_SKIP_INTEGRATION = os.getenv("SHIBA_SKIP_INTEGRATION") == "1" + + +@pytest.fixture(scope="session") +def mysql_url() -> Iterator[str]: + if _SKIP_INTEGRATION or MySqlContainer is None: + pytest.skip("integration tests deshabilitados o testcontainers no instalado") + with MySqlContainer("mysql:8.0", dialect="pymysql") as container: + yield container.get_connection_url() + + +@pytest.fixture(scope="session") +def postgres_url() -> Iterator[str]: + if _SKIP_INTEGRATION or PostgresContainer is None: + pytest.skip("integration tests deshabilitados o testcontainers no instalado") + with PostgresContainer("postgres:16-alpine", driver=None) as container: + yield container.get_connection_url() + + +def _normalize_to_shiba_dsn(url: str, scheme: str) -> str: + """``mysql+pymysql://...`` → ``mysql://...`` (Shiba ignora el driver hint).""" + if "+" in url.split("://", 1)[0]: + _, tail = url.split("://", 1) + return f"{scheme}://{tail}" + return url + + +@pytest.fixture +def mysql_cx(mysql_url: str) -> Iterator[shiba.ShibaConnection]: + dsn = _normalize_to_shiba_dsn(mysql_url, "mysql") + cx = shiba.connect(dsn) + try: + yield cx + finally: + cx.close() + + +@pytest.fixture +def postgres_cx(postgres_url: str) -> Iterator[shiba.ShibaConnection]: + dsn = _normalize_to_shiba_dsn(postgres_url, "postgres") + cx = shiba.connect(dsn) + try: + yield cx + finally: + cx.close() diff --git a/tests/integration/test_mysql_e2e.py b/tests/integration/test_mysql_e2e.py new file mode 100644 index 0000000..3655d02 --- /dev/null +++ b/tests/integration/test_mysql_e2e.py @@ -0,0 +1,93 @@ +"""Smoke end-to-end contra MySQL 8 real.""" +from __future__ import annotations + +import json + +import pytest + +pytestmark = pytest.mark.integration + + +def test_create_and_crud(mysql_cx) -> None: + mysql_cx.create_table("users").increments("id", primary_key=True).string( + "name", 64 + ).integer("age").json("settings").build() + try: + tbl = mysql_cx.table("users") + tbl.insert({"name": "Alice", "age": 30, "settings": json.dumps({"theme": "dark"})}) + tbl.insert({"name": "Bob", "age": 18, "settings": json.dumps({"theme": "light"})}) + + rows = mysql_cx.table("users").order_by("age").get() + assert [r["name"] for r in rows] == ["Bob", "Alice"] + + n = mysql_cx.table("users").where("age", ">=", 18).count() + assert n == 2 + + mysql_cx.table("users").where("name", "Bob").update({"age": 19}) + bob = mysql_cx.table("users").where("name", "Bob").first() + assert bob is not None and bob["age"] == 19 + + mysql_cx.table("users").where("name", "Bob").delete() + assert mysql_cx.table("users").count() == 1 + finally: + mysql_cx.raw("DROP TABLE IF EXISTS users") + + +def test_upsert(mysql_cx) -> None: + mysql_cx.raw("DROP TABLE IF EXISTS items") + mysql_cx.create_table("items").integer("id").primary().string("name").build() + try: + mysql_cx.table("items").upsert({"id": 1, "name": "A"}) + mysql_cx.table("items").upsert({"id": 1, "name": "B"}) # actualiza + rows = mysql_cx.table("items").get() + assert len(rows) == 1 + assert rows[0]["name"] == "B" + finally: + mysql_cx.raw("DROP TABLE IF EXISTS items") + + +def test_transaction_rollback(mysql_cx) -> None: + mysql_cx.raw("DROP TABLE IF EXISTS t") + mysql_cx.create_table("t").integer("id").primary().build() + try: + with pytest.raises(RuntimeError), mysql_cx.transaction(): + mysql_cx.table("t").insert({"id": 1}) + raise RuntimeError("abort") + assert mysql_cx.table("t").count() == 0 + + with mysql_cx.transaction(): + mysql_cx.table("t").insert({"id": 1}) + assert mysql_cx.table("t").count() == 1 + finally: + mysql_cx.raw("DROP TABLE IF EXISTS t") + + +def test_orm_end_to_end(mysql_cx) -> None: + from shiba import Model, fields, set_default_connection + + set_default_connection(mysql_cx) + + class Customer(Model): + __table__ = "customers" + id: int = fields.PrimaryKey() + name: str + active: bool = True + + mysql_cx.raw("DROP TABLE IF EXISTS customers") + Customer.create_table() + try: + Customer(name="Alice").save() + Customer(name="Bob", active=False).save() + + rows = Customer.where("active", True).get() + assert len(rows) == 1 + assert isinstance(rows[0], Customer) + assert rows[0].name == "Alice" + + first = Customer.find(1) + assert first is not None and first.name == "Alice" + first.name = "Alice Renamed" + first.save() + assert Customer.find(1).name == "Alice Renamed" # type: ignore[union-attr] + finally: + mysql_cx.raw("DROP TABLE IF EXISTS customers") diff --git a/tests/integration/test_postgres_e2e.py b/tests/integration/test_postgres_e2e.py new file mode 100644 index 0000000..f5dfde5 --- /dev/null +++ b/tests/integration/test_postgres_e2e.py @@ -0,0 +1,59 @@ +"""Smoke end-to-end contra Postgres 16 real.""" +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.integration + + +def test_create_and_crud(postgres_cx) -> None: + postgres_cx.raw("DROP TABLE IF EXISTS users") + postgres_cx.create_table("users").increments("id", primary_key=True).string( + "name", 64 + ).integer("age").json("settings").build() + try: + tbl = postgres_cx.table("users") + tbl.insert({"name": "Alice", "age": 30, "settings": '{"theme": "dark"}'}) + tbl.insert({"name": "Bob", "age": 18, "settings": '{"theme": "light"}'}) + + rows = postgres_cx.table("users").order_by("age").get() + assert [r["name"] for r in rows] == ["Bob", "Alice"] + + n = postgres_cx.table("users").where("age", ">=", 18).count() + assert n == 2 + + postgres_cx.table("users").where("name", "Bob").update({"age": 19}) + bob = postgres_cx.table("users").where("name", "Bob").first() + assert bob is not None and bob["age"] == 19 + + postgres_cx.table("users").where("name", "Bob").delete() + assert postgres_cx.table("users").count() == 1 + finally: + postgres_cx.raw("DROP TABLE IF EXISTS users") + + +def test_upsert_with_on(postgres_cx) -> None: + postgres_cx.raw("DROP TABLE IF EXISTS items") + postgres_cx.create_table("items").integer("id").primary().string("name").build() + try: + postgres_cx.table("items").upsert({"id": 1, "name": "A"}, on=["id"]) + postgres_cx.table("items").upsert({"id": 1, "name": "B"}, on=["id"]) + rows = postgres_cx.table("items").get() + assert len(rows) == 1 + assert rows[0]["name"] == "B" + finally: + postgres_cx.raw("DROP TABLE IF EXISTS items") + + +def test_identity_pk_generated(postgres_cx) -> None: + postgres_cx.raw("DROP TABLE IF EXISTS pk_test") + postgres_cx.create_table("pk_test").increments("id", primary_key=True).string( + "label" + ).build() + try: + postgres_cx.table("pk_test").insert({"label": "a"}) + postgres_cx.table("pk_test").insert({"label": "b"}) + ids = [r["id"] for r in postgres_cx.table("pk_test").order_by("id").get()] + assert ids == [1, 2] + finally: + postgres_cx.raw("DROP TABLE IF EXISTS pk_test") diff --git a/tests/test_connect_factory.py b/tests/test_connect_factory.py new file mode 100644 index 0000000..309a7c1 --- /dev/null +++ b/tests/test_connect_factory.py @@ -0,0 +1,81 @@ +"""`shiba.connect(dsn)` resuelve el dialecto desde la URL.""" +from __future__ import annotations + +import pytest + +import shiba +from shiba import error_codes +from shiba.errors import ShibaError + + +def test_unsupported_scheme_raises() -> None: + with pytest.raises(ShibaError) as ei: + shiba.connect("oracle://u:p@h/db") + assert ei.value.code is error_codes.NOT_IMPLEMENTED + + +def test_mysql_dsn_picks_mysql_dialect(monkeypatch) -> None: + """Sin levantar conexión real, verificamos que el factory escogería MySQL.""" + constructed: dict[str, object] = {} + + class FakeMySQLDb: + def __init__(self, host, port, user, password, *, database=None): + constructed.update( + kind="mysql", + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + monkeypatch.setattr(shiba, "Database", FakeMySQLDb) + cx = shiba.connect("mysql://alice:secret@db.host:3307/app") + assert constructed["kind"] == "mysql" + assert constructed["host"] == "db.host" + assert constructed["port"] == 3307 + assert constructed["user"] == "alice" + assert constructed["database"] == "app" + assert cx.dialect.name == "mysql" + + +def test_postgres_dsn_picks_postgres_dialect(monkeypatch) -> None: + constructed: dict[str, object] = {} + + class FakePgDb: + def __init__(self, host, port, user, password, *, database=None): + constructed.update( + kind="postgres", + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + # Reemplazamos el import perezoso del driver Postgres por el fake. + import shiba.dialects.postgres.driver as pg_driver + + monkeypatch.setattr(pg_driver, "Database", FakePgDb) + cx = shiba.connect("postgres://bob:hunter2@pg.host/app") + assert constructed["kind"] == "postgres" + assert constructed["host"] == "pg.host" + assert constructed["port"] == 5432 # default + assert constructed["user"] == "bob" + assert constructed["database"] == "app" + assert cx.dialect.name == "postgres" + + +def test_postgresql_alias() -> None: + """``postgresql://`` y ``postgres://`` deberían comportarse igual.""" + import shiba.dialects.postgres.driver as pg_driver + + calls = [] + + class FakePgDb: + def __init__(self, *a, **kw): + calls.append((a, kw)) + + pg_driver.Database = FakePgDb # type: ignore[misc] + cx = shiba.connect("postgresql://x:y@h/d") + assert cx.dialect.name == "postgres" diff --git a/tests/test_orm.py b/tests/test_orm.py new file mode 100644 index 0000000..a685c26 --- /dev/null +++ b/tests/test_orm.py @@ -0,0 +1,293 @@ +"""Cobertura del ORM (fields + Model + ModelQuery).""" +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal +from typing import Any + +import pytest + +from shiba import Model, error_codes, fields, set_default_connection +from shiba.core.query_builder import QueryBuilder +from shiba.core.table_builder import TableBuilder +from shiba.dialects.mysql import MySQLDialect +from shiba.errors import MissingDataError, ShibaError + +# --------------------------------------------------------------------------- +# Fake connection que reusa el FakeDatabase de conftest +# --------------------------------------------------------------------------- + + +class FakeShibaConnection: + def __init__(self, fake_db: Any) -> None: + self.db = fake_db + self.dialect = MySQLDialect() + + def table(self, name: str) -> QueryBuilder: + return QueryBuilder(self.db, name, dialect=self.dialect) + + def create_table(self, name: str) -> TableBuilder: + return TableBuilder(self.db, name, dialect=self.dialect) + + def raw(self, query: str, params: Any = None, *, many: bool = False): + return self.db.execute(query, params, many=many) + + +@pytest.fixture +def cx(fake_db) -> FakeShibaConnection: + conn = FakeShibaConnection(fake_db) + set_default_connection(conn) # type: ignore[arg-type] + return conn + + +# --------------------------------------------------------------------------- +# Field inference +# --------------------------------------------------------------------------- + + +def test_infer_types_from_annotations() -> None: + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + age: int | None = None + active: bool = True + score: float = 0.0 + amount: Decimal | None = None + settings: dict = fields.Json(default_factory=dict) + notes: str | None = None + created_at: datetime = fields.DateTime(default_now=True) + + f = T._fields + assert f["id"].primary_key and f["id"].auto_increment + assert f["name"].sql_type == "VARCHAR(255)" and not f["name"].nullable + assert f["age"].sql_type == "INT" and f["age"].nullable + assert f["active"].sql_type == "BOOLEAN" and f["active"].default is True + assert f["score"].sql_type == "FLOAT" + assert f["amount"].sql_type == "DECIMAL" and f["amount"].nullable + assert f["settings"].json and f["settings"].sql_type == "JSON" + assert f["notes"].nullable + assert f["created_at"].default_factory is not None + + +def test_unknown_kwarg_rejected(cx) -> None: + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + + with pytest.raises(ShibaError) as ei: + T(name="x", oops=1) + assert ei.value.code is error_codes.INVALID_DATA_FORMAT + + +def test_pk_required(cx) -> None: + class NoPk(Model): + __table__ = "nopk" + name: str + + with pytest.raises(MissingDataError) as ei: + NoPk(name="x").save() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +# --------------------------------------------------------------------------- +# Persistencia +# --------------------------------------------------------------------------- + + +def test_save_insert(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + age: int | None = None + + # Para que save() interprete LAST_INSERT_ID le damos un retorno. + seq = iter([[], [{"v": 42}]]) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(seq) + + fake_db.execute = execute # type: ignore[method-assign] + + u = User(name="John", age=30) + u.save() + assert u.id == 42 + insert_sql, params, _ = fake_db.calls[0] + assert insert_sql.startswith("INSERT INTO `users` (`name`, `age`)") + assert params == ("John", 30) + + +def test_save_update_when_pk_present(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + u = User(id=5, name="John") + u.save() + sql, params, _ = fake_db.last_call + assert sql == "UPDATE `users` SET `id` = %s, `name` = %s WHERE `id` = %s" + assert params == (5, "John", 5) + + +def test_delete(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + User(id=7, name="X").delete() + sql, params, _ = fake_db.last_call + assert sql == "DELETE FROM `users` WHERE `id` = %s" + assert params == (7,) + + +def test_delete_without_pk_raises(cx) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + with pytest.raises(MissingDataError) as ei: + User(name="X").delete() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +# --------------------------------------------------------------------------- +# Lectura +# --------------------------------------------------------------------------- + + +def test_find_returns_model_instance(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + fake_db.result = [{"id": 3, "name": "Alice"}] + u = User.find(3) + assert isinstance(u, User) + assert u.id == 3 + assert u.name == "Alice" + + +def test_find_returns_none_when_missing(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + fake_db.result = [] + assert User.find(99) is None + + +def test_where_returns_modelquery(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + age: int + + fake_db.result = [ + {"id": 1, "name": "A", "age": 20}, + {"id": 2, "name": "B", "age": 30}, + ] + rows = User.where("age", ">", 10).order_by("age").get() + assert all(isinstance(r, User) for r in rows) + assert [r.name for r in rows] == ["A", "B"] + sql, params, _ = fake_db.last_call + assert "WHERE `age` > %s" in sql + assert "ORDER BY `age` ASC" in sql + assert params == (10,) + + +def test_json_field_roundtrips(cx, fake_db) -> None: + class Doc(Model): + __table__ = "docs" + id: int = fields.PrimaryKey() + payload: dict = fields.Json(default_factory=dict) + + fake_db.result = [{"id": 1, "payload": '{"a": 1}'}] + d = Doc.find(1) + assert d is not None + assert d.payload == {"a": 1} + + # Y al guardar, payload se serializa. + fake_db.result = [] + Doc(payload={"x": "y"}).save() + insert_calls = [c for c in fake_db.calls if c[0].startswith("INSERT INTO `docs`")] + assert insert_calls, "esperaba al menos un INSERT" + _, params, _ = insert_calls[-1] + assert params == ('{"x": "y"}',) + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + + +def test_create_table_emits_ddl(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str = fields.String(max_length=50) + email: str = fields.String(unique=True) + age: int | None = None + settings: dict = fields.Json(default_factory=dict) + + User.create_table() + sql, _, _ = fake_db.last_call + assert "CREATE TABLE IF NOT EXISTS `users`" in sql + assert "`id` INT AUTO_INCREMENT PRIMARY KEY" in sql + assert "`name` VARCHAR(50) NOT NULL" in sql + assert "`email` VARCHAR(255) UNIQUE NOT NULL" in sql + assert "`age` INT NULL" in sql + assert "`settings` JSON NOT NULL" in sql + + +def test_truncate_table(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + User.truncate_table() + sql, _, _ = fake_db.last_call + assert sql == "TRUNCATE TABLE `users`" + + +def test_default_table_name_from_class(cx) -> None: + class Customer(Model): + id: int = fields.PrimaryKey() + name: str + + assert Customer._table == "customer" + + +def test_to_dict_and_from_row_roundtrip(cx) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + u = User.from_row({"id": 1, "name": "X"}) + assert u.to_dict() == {"id": 1, "name": "X"} + + +def test_no_connection_raises(monkeypatch) -> None: + from shiba.orm import model as model_mod + + monkeypatch.setattr(model_mod, "_default_connection", None) + + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + + T.__db__ = None + with pytest.raises(ShibaError) as ei: + T.find(1) + assert ei.value.code is error_codes.CONNECTION_NOT_OPEN diff --git a/tests/test_postgres_dialect.py b/tests/test_postgres_dialect.py new file mode 100644 index 0000000..010eb7d --- /dev/null +++ b/tests/test_postgres_dialect.py @@ -0,0 +1,92 @@ +"""PostgresDialect emite SQL en la sintaxis idiomática del motor.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.core.table_builder import TableBuilder +from shiba.dialects.postgres.dialect import PostgresDialect +from shiba.errors import MissingDataError, SchemaError + + +@pytest.fixture +def pg() -> PostgresDialect: + return PostgresDialect() + + +def test_quotes_with_double_quotes(pg) -> None: + assert pg.quote_identifier("users") == '"users"' + assert pg.quote_identifier("schema.table") == '"schema"."table"' + + +def test_quote_rejects_injection(pg) -> None: + with pytest.raises(SchemaError) as ei: + pg.quote_identifier('users"; DROP TABLE x; --') + assert ei.value.code is error_codes.INVALID_IDENTIFIER + + +def test_select_uses_double_quotes(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).where("name", "John").get() + sql, params, _ = fake_db.last_call + assert sql.startswith('SELECT * FROM "users"') + assert 'WHERE "name" = %s' in sql + assert params == ("John",) + + +def test_create_table_uses_identity_pk(fake_db, pg) -> None: + sql = ( + TableBuilder(fake_db, "users", dialect=pg) + .increments("id", primary_key=True) + .string("name") + .json("settings") + .to_sql() + ) + assert 'CREATE TABLE IF NOT EXISTS "users"' in sql + assert '"id" INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY' in sql + assert '"name" VARCHAR(255)' in sql + assert '"settings" JSONB' in sql # JSON → JSONB + + +def test_datetime_maps_to_timestamp(fake_db, pg) -> None: + sql = TableBuilder(fake_db, "t", dialect=pg).datetime("created_at").to_sql() + assert '"created_at" TIMESTAMP' in sql + assert "DATETIME" not in sql + + +def test_blob_maps_to_bytea(fake_db, pg) -> None: + sql = TableBuilder(fake_db, "t", dialect=pg).binary("data").to_sql() + assert '"data" BYTEA' in sql + + +def test_upsert_emits_on_conflict(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).upsert( + {"id": 1, "name": "X"}, on=["id"] + ) + sql, params, _ = fake_db.last_call + assert sql == ( + 'INSERT INTO "users" ("id", "name") VALUES (%s, %s) ' + 'ON CONFLICT ("id") DO UPDATE SET "id" = EXCLUDED."id", "name" = EXCLUDED."name"' + ) + assert params == (1, "X") + + +def test_upsert_without_on_in_postgres_raises(fake_db, pg) -> None: + with pytest.raises(MissingDataError) as ei: + QueryBuilder(fake_db, "users", dialect=pg).upsert({"id": 1, "name": "X"}) + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +def test_upsert_empty_update_does_nothing(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).upsert( + {"id": 1, "name": "X"}, update=[], on=["id"] + ) + sql, _, _ = fake_db.last_call + assert 'ON CONFLICT ("id") DO NOTHING' in sql + + +def test_mysql_upsert_still_works_without_on(fake_db, dialect) -> None: + """`on` es opcional en MySQL (lo detecta la PK).""" + QueryBuilder(fake_db, "users", dialect=dialect).upsert({"id": 1, "name": "X"}) + sql, _, _ = fake_db.last_call + assert "ON DUPLICATE KEY UPDATE" in sql diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index e1a3539..572f4a8 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -132,8 +132,8 @@ def test_delete_with_where(fake_db, dialect) -> None: def test_count(fake_db, dialect) -> None: - fake_db.result = [{"cnt": 42}] + fake_db.result = [{"v": 42}] n = QueryBuilder(fake_db, "users", dialect=dialect).where("active", True).count() assert n == 42 sql, _, _ = fake_db.last_call - assert sql.startswith("SELECT COUNT(*) AS cnt FROM `users`") + assert sql.startswith("SELECT COUNT(*) AS v FROM `users`") diff --git a/tests/test_query_builder_fase1.py b/tests/test_query_builder_fase1.py new file mode 100644 index 0000000..8d9a59c --- /dev/null +++ b/tests/test_query_builder_fase1.py @@ -0,0 +1,242 @@ +"""Cobertura de features añadidas en Fase 1. + +WHERE: or_where, where_in/null/between/like, where_group, where_json. +GROUP BY / HAVING. paginate / chunk / iterate. find / exists / pluck. +sum/avg/min/max. upsert. truncate. raw(). +""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.errors import QueryError, SchemaError + +# --------------------------------------------------------------------------- +# WHERE variants +# --------------------------------------------------------------------------- + +def test_or_where(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("a", 1).or_where("b", 2).get() + sql, params, _ = fake_db.last_call + assert "WHERE `a` = %s OR `b` = %s" in sql + assert params == (1, 2) + + +def test_where_in_alias(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_in("id", [1, 2]).get() + sql, params, _ = fake_db.last_call + assert "WHERE `id` IN (%s, %s)" in sql + assert params == (1, 2) + + +def test_where_null(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_null("deleted_at").get() + sql, _, _ = fake_db.last_call + assert "WHERE `deleted_at` IS NULL" in sql + + +def test_where_not_null(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_not_null("email").get() + sql, _, _ = fake_db.last_call + assert "WHERE `email` IS NOT NULL" in sql + + +def test_where_between(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_between("age", 18, 65).get() + sql, params, _ = fake_db.last_call + assert "WHERE `age` BETWEEN %s AND %s" in sql + assert params == (18, 65) + + +def test_where_like(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_like("name", "John%").get() + sql, params, _ = fake_db.last_call + assert "WHERE `name` LIKE %s" in sql + assert params == ("John%",) + + +def test_where_group_mixes_and_or(fake_db, dialect) -> None: + ( + QueryBuilder(fake_db, "u", dialect=dialect) + .where("active", True) + .where_group(lambda g: g.where("role", "admin").or_where("role", "owner")) + .get() + ) + sql, params, _ = fake_db.last_call + assert "WHERE `active` = %s AND (`role` = %s OR `role` = %s)" in sql + assert params == (True, "admin", "owner") + + +# --------------------------------------------------------------------------- +# JSON +# --------------------------------------------------------------------------- + +def test_where_json_extracts_path(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).where_json("settings", "$.theme", "dark").get() + sql, params, _ = fake_db.last_call + assert "JSON_UNQUOTE(JSON_EXTRACT(`settings`, '$.theme')) = %s" in sql + assert params == ("dark",) + + +def test_where_json_rejects_unsafe_path(fake_db, dialect) -> None: + with pytest.raises(QueryError) as ei: + QueryBuilder(fake_db, "u", dialect=dialect).where_json( + "settings", "$.theme'; DROP TABLE x; --", "dark" + ) + assert ei.value.code is error_codes.INVALID_QUERY_PARAMS + + +# --------------------------------------------------------------------------- +# GROUP BY / HAVING +# --------------------------------------------------------------------------- + +def test_group_by_and_having(fake_db, dialect) -> None: + fake_db.result = [{"v": 0}] + ( + QueryBuilder(fake_db, "orders", dialect=dialect) + .select("user_id") + .group_by("user_id") + .having("total", ">", 100) + .get() + ) + sql, params, _ = fake_db.last_call + assert "GROUP BY `user_id`" in sql + assert "HAVING `total` > %s" in sql + assert params == (100,) + + +# --------------------------------------------------------------------------- +# Aggregates +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "method,fn", + [("sum", "SUM"), ("avg", "AVG"), ("min", "MIN"), ("max", "MAX")], +) +def test_aggregates(fake_db, dialect, method: str, fn: str) -> None: + fake_db.result = [{"v": 7}] + qb = QueryBuilder(fake_db, "u", dialect=dialect) + result = getattr(qb, method)("age") + assert result == 7 + sql, _, _ = fake_db.last_call + assert f"SELECT {fn}(`age`) AS v FROM `users`".replace("`users`", "`u`") in sql + + +# --------------------------------------------------------------------------- +# Convenience reads +# --------------------------------------------------------------------------- + +def test_find_by_pk(fake_db, dialect) -> None: + fake_db.result = [{"id": 5, "name": "X"}] + row = QueryBuilder(fake_db, "u", dialect=dialect).find(5) + assert row == {"id": 5, "name": "X"} + sql, params, _ = fake_db.last_call + assert "WHERE `id` = %s" in sql + assert "LIMIT 1" in sql + assert params == (5,) + + +def test_exists_true_false(fake_db, dialect) -> None: + fake_db.result = [{"v": 3}] + assert QueryBuilder(fake_db, "u", dialect=dialect).exists() is True + fake_db.result = [{"v": 0}] + assert QueryBuilder(fake_db, "u", dialect=dialect).exists() is False + + +def test_pluck(fake_db, dialect) -> None: + fake_db.result = [{"name": "A"}, {"name": "B"}] + names = QueryBuilder(fake_db, "u", dialect=dialect).pluck("name") + assert names == ["A", "B"] + sql, _, _ = fake_db.last_call + assert sql.startswith("SELECT `name` FROM `u`") + + +# --------------------------------------------------------------------------- +# Paginate +# --------------------------------------------------------------------------- + +def test_paginate(fake_db, dialect) -> None: + # Primer execute = COUNT; segundo = data. + results = iter([[{"v": 53}], [{"id": i} for i in range(1, 26)]]) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(results) + + fake_db.execute = execute # type: ignore[method-assign] + + page = QueryBuilder(fake_db, "u", dialect=dialect).paginate(page=2, per_page=25) + assert page["page"] == 2 + assert page["per_page"] == 25 + assert page["total"] == 53 + assert page["last_page"] == 3 + assert len(page["data"]) == 25 + # La segunda llamada debe tener LIMIT 25 OFFSET 25. + data_sql, _, _ = fake_db.calls[-1] + assert "LIMIT 25" in data_sql + assert "OFFSET 25" in data_sql + + +def test_paginate_rejects_zero(fake_db, dialect) -> None: + with pytest.raises(QueryError) as ei: + QueryBuilder(fake_db, "u", dialect=dialect).paginate(page=0, per_page=10) + assert ei.value.code is error_codes.INVALID_QUERY_PARAMS + + +# --------------------------------------------------------------------------- +# Chunk / iterate +# --------------------------------------------------------------------------- + +def test_chunk_iterates_until_empty(fake_db, dialect) -> None: + pages = iter( + [ + [{"id": 1}, {"id": 2}], + [{"id": 3}, {"id": 4}], + [{"id": 5}], + [], # nunca debería pedirse pero por seguridad + ] + ) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(pages) + + fake_db.execute = execute # type: ignore[method-assign] + + collected: list[dict[str, int]] = [] + QueryBuilder(fake_db, "u", dialect=dialect).chunk(2, collected.extend) + assert [r["id"] for r in collected] == [1, 2, 3, 4, 5] + + +# --------------------------------------------------------------------------- +# UPSERT / TRUNCATE / RAW +# --------------------------------------------------------------------------- + +def test_upsert_emits_on_duplicate_key(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).upsert({"id": 1, "name": "X"}) + sql, params, _ = fake_db.last_call + assert sql == ( + "INSERT INTO `u` (`id`, `name`) VALUES (%s, %s) " + "ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`)" + ) + assert params == (1, "X") + + +def test_upsert_with_explicit_update_columns(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).upsert( + {"id": 1, "name": "X", "age": 30}, update=["name", "age"] + ) + sql, _, _ = fake_db.last_call + assert "ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `age` = VALUES(`age`)" in sql + + +def test_truncate(fake_db, dialect) -> None: + QueryBuilder(fake_db, "u", dialect=dialect).truncate() + sql, _, _ = fake_db.last_call + assert sql == "TRUNCATE TABLE `u`" + + +def test_invalid_column_in_pluck(fake_db, dialect) -> None: + with pytest.raises(SchemaError): + QueryBuilder(fake_db, "u", dialect=dialect).pluck("name; DROP")