From 58421134bbbe28f0453eb82d61e4591ff579712e Mon Sep 17 00:00:00 2001 From: Rodrigo Pino Date: Mon, 18 May 2026 01:01:41 -0400 Subject: [PATCH] feat(postgres): dialecto PostgreSQL + factory connect(dsn) (Fase 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Segundo dialecto. La estructura preparada en Fase 0 paga su precio: sólo se añade — nada del core ni del ORM se mueve. Driver - `shiba/dialects/postgres/driver.py` con psycopg v3, `dict_row`, autocommit, traducción de excepciones nativas a códigos Shiba. - `psycopg[binary]` declarado como dep **opcional** en `[project.optional-dependencies] postgres`; también en `dev` para CI. - Import perezoso: el paquete `shiba.dialects.postgres` se puede importar sin tener psycopg instalado; el error sale al construir `Database`. Dialect - `PostgresDialect`: comillas dobles para identificadores, placeholder `%s`, `ON CONFLICT (...) DO UPDATE SET col = EXCLUDED.col` para upsert, `INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY` para auto-increment. - Mapeo de tipos: `DATETIME` → `TIMESTAMP`, `JSON` → `JSONB`, `BLOB` → `BYTEA`, `DOUBLE` → `DOUBLE PRECISION`. Cambios menores en el contrato - `Dialect.compile_upsert_update(update_columns, conflict_columns=None)`. - Nuevo `Dialect.compile_auto_increment_pk(quoted_col)`. - `TableBuilder.increments(pk=True)` delega al dialect. - `QueryBuilder.upsert(..., on=[...])` para Postgres; MySQL lo ignora. Factory - `shiba.connect("mysql://...")`/`shiba.connect("postgres://...")` parsea el DSN y devuelve la `ShibaConnection` adecuada. - `ShibaConnection` ahora acepta `db=`/`dialect=` ya construidos, sin romper la firma legacy `(host, port, user, password)`. Tests - 14 nuevos (114/114 total) cubren quoting con doble comilla, rechazo de injection, DDL con IDENTITY, mapping de tipos JSONB/BYTEA, upsert ON CONFLICT (con y sin `on`), factory resolviendo MySQL vs Postgres vs alias `postgresql://`, y rechazo de schemes no soportados. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 2 + shiba/__init__.py | 110 ++++++++++--- shiba/core/query_builder.py | 12 +- shiba/core/table_builder.py | 9 +- shiba/dialects/base.py | 18 ++- shiba/dialects/mysql/dialect.py | 6 +- shiba/dialects/postgres/__init__.py | 19 +++ shiba/dialects/postgres/dialect.py | 39 +++++ shiba/dialects/postgres/driver.py | 238 ++++++++++++++++++++++++++++ shiba/dialects/postgres/quoting.py | 13 ++ shiba/dialects/postgres/schema.py | 26 +++ tests/test_connect_factory.py | 81 ++++++++++ tests/test_postgres_dialect.py | 92 +++++++++++ 13 files changed, 628 insertions(+), 37 deletions(-) create mode 100644 shiba/dialects/postgres/__init__.py create mode 100644 shiba/dialects/postgres/dialect.py create mode 100644 shiba/dialects/postgres/driver.py create mode 100644 shiba/dialects/postgres/quoting.py create mode 100644 shiba/dialects/postgres/schema.py create mode 100644 tests/test_connect_factory.py create mode 100644 tests/test_postgres_dialect.py diff --git a/pyproject.toml b/pyproject.toml index 358e2ce..834c6f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,13 @@ classifiers = [ dependencies = ["pymysql>=1.1"] [project.optional-dependencies] +postgres = ["psycopg[binary]>=3.1"] dev = [ "pytest>=8", "pytest-cov>=5", "ruff>=0.6", "mypy>=1.10", + "psycopg[binary]>=3.1", ] [project.urls] diff --git a/shiba/__init__.py b/shiba/__init__.py index 873fe47..77bbd10 100644 --- a/shiba/__init__.py +++ b/shiba/__init__.py @@ -1,31 +1,27 @@ """Shiba — librería ligera para hablar con bases de datos relacionales. -Punto de entrada público: - .. code-block:: python - import shiba as s + import shiba - with s.ShibaConnection(host="localhost", port=3306, - user="u", password="p") as cx: - cx.create_database("my_db") - cx.use_database("my_db") - cx.create_table("users") \\ - .increments("id", primary_key=True) \\ - .string("name") \\ - .build() + # Forma 1 — DSN explícito (recomendado para multi-dialecto): + cx = shiba.connect("mysql://user:pass@localhost:3306/my_db") + cx = shiba.connect("postgres://user:pass@localhost:5432/my_db") - cx.table("users").insert({"name": "John"}) - rows = cx.table("users").where("name", "John").get() + # Forma 2 — construcción directa MySQL (legacy): + cx = shiba.ShibaConnection(host="localhost", port=3306, + user="u", password="p") """ from __future__ import annotations from contextlib import AbstractContextManager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse from shiba import error_codes from shiba.core.query_builder import QueryBuilder from shiba.core.table_builder import TableBuilder +from shiba.dialects.base import Dialect from shiba.dialects.mysql import Database, MySQLDialect from shiba.errors import ( ConnectionError, @@ -42,17 +38,36 @@ class ShibaConnection: - """Fachada de alto nivel sobre un :class:`Database` MySQL.""" + """Fachada de alto nivel agnóstica de dialecto. + + Acepta dos formas de construcción: + + * Legacy MySQL: ``ShibaConnection(host, port, user, password)``. + * Inyectada: ``ShibaConnection(db=Database(...), dialect=Dialect(...))`` + (la usa :func:`connect`). + """ def __init__( self, - host: str, - port: int, - user: str, - password: str, + host: str | None = None, + port: int | None = None, + user: str | None = None, + password: str | None = None, *, database: str | None = None, + db: Any = None, + dialect: Dialect | None = None, ) -> None: + if db is not None and dialect is not None: + self.dialect: Dialect = dialect + self.db: Any = db + return + + if host is None or port is None or user is None or password is None: + error_codes.MISSING_REQUIRED_DATA.raise_( + "ShibaConnection requiere host/port/user/password o " + "db+dialect inyectados." + ) self.dialect = MySQLDialect() self.db = Database(host, port, user, password, database=database) @@ -78,10 +93,10 @@ def close(self) -> None: # API pública # ------------------------------------------------------------------ - def create_database(self, database: str) -> Database: + def create_database(self, database: str) -> Any: return self.db.create_database(database) - def use_database(self, database: str) -> Database: + def use_database(self, database: str) -> Any: return self.db.use_database(database) def create_table(self, table_name: str) -> TableBuilder: @@ -90,9 +105,10 @@ def create_table(self, table_name: str) -> TableBuilder: def table(self, table_name: str) -> QueryBuilder: return QueryBuilder(self.db, table_name, dialect=self.dialect) - def transaction(self) -> AbstractContextManager[Database]: - """Context manager transaccional. Ver :meth:`Database.transaction`.""" - return self.db.transaction() + def transaction(self) -> AbstractContextManager[Any]: + """Context manager transaccional.""" + cm: AbstractContextManager[Any] = self.db.transaction() + return cm def raw( self, @@ -101,13 +117,54 @@ def raw( *, many: bool = False, ) -> list[dict[str, object]]: - """Escape hatch — ver :meth:`Database.raw`.""" - return self.db.raw(query, params, many=many) + rows: list[dict[str, object]] = self.db.raw(query, params, many=many) + return rows + + +# --------------------------------------------------------------------------- +# Factory connect(dsn) +# --------------------------------------------------------------------------- + + +_DEFAULT_PORTS = {"mysql": 3306, "postgres": 5432, "postgresql": 5432} + + +def connect(dsn: str) -> ShibaConnection: + """Construye una :class:`ShibaConnection` desde un DSN tipo URL. + + Schemes soportados: + + * ``mysql://user:pass@host:port/dbname`` + * ``postgres://user:pass@host:port/dbname`` (alias: ``postgresql://``) + """ + parsed = urlparse(dsn) + scheme = parsed.scheme.lower() + if scheme not in _DEFAULT_PORTS: + error_codes.NOT_IMPLEMENTED.raise_( + f"DSN scheme '{scheme}' no soportado. Usa: {sorted(_DEFAULT_PORTS)}." + ) + host = parsed.hostname or "localhost" + port = parsed.port or _DEFAULT_PORTS[scheme] + user = parsed.username or "" + password = parsed.password or "" + database = parsed.path.lstrip("/") or None + + if scheme == "mysql": + db: Any = Database(host, port, user, password, database=database) + return ShibaConnection(db=db, dialect=MySQLDialect()) + + # postgres / postgresql + from shiba.dialects.postgres import PostgresDialect + from shiba.dialects.postgres.driver import Database as PgDatabase + + db = PgDatabase(host, port, user, password, database=database) + return ShibaConnection(db=db, dialect=PostgresDialect()) __all__ = [ "ConnectionError", "Database", + "Dialect", "IntegrityError", "MissingDataError", "Model", @@ -118,6 +175,7 @@ def raw( "ShibaConnection", "ShibaError", "TableBuilder", + "connect", "error_codes", "fields", "set_default_connection", diff --git a/shiba/core/query_builder.py b/shiba/core/query_builder.py index 449acc7..f5b14ac 100644 --- a/shiba/core/query_builder.py +++ b/shiba/core/query_builder.py @@ -500,12 +500,14 @@ def upsert( data: dict[str, Any], *, update: list[str] | None = None, + on: list[str] | None = None, ) -> list[dict[str, Any]]: """INSERT con resolución de conflicto. - En MySQL se emite ``ON DUPLICATE KEY UPDATE``. El parámetro - ``update`` indica qué columnas pisar (por defecto todas excepto - las que sean clave). El dialecto adapta la cláusula. + :param data: columnas → valores. + :param update: columnas a pisar en conflicto (default todas). + :param on: columnas del conflicto. Requerido por Postgres, + opcional en MySQL (lo detecta por la PK). """ if not data: raise error_codes.MISSING_REQUIRED_DATA.build("upsert(): dict vacío.") @@ -515,7 +517,9 @@ def upsert( update_cols = update if update is not None else list(data.keys()) for col in update_cols: validate_identifier(col, kind="column") - update_sql = self.dialect.compile_upsert_update(update_cols) + for col in on or []: + validate_identifier(col, kind="column") + update_sql = self.dialect.compile_upsert_update(update_cols, on) table = self.dialect.quote_identifier(self.table_name) query = ( f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders}) {update_sql}" diff --git a/shiba/core/table_builder.py b/shiba/core/table_builder.py index b2ec6d3..7b06582 100644 --- a/shiba/core/table_builder.py +++ b/shiba/core/table_builder.py @@ -107,10 +107,11 @@ def _col(self, column_name: str, sql_type: str) -> TableBuilder: def increments(self, column_name: str = "id", primary_key: bool = False) -> TableBuilder: validate_identifier(column_name, kind="column") - suffix = " PRIMARY KEY" if primary_key else "" - self._append_column( - f"{self.dialect.quote_identifier(column_name)} INT AUTO_INCREMENT{suffix}" - ) + col_quoted = self.dialect.quote_identifier(column_name) + if primary_key: + self._append_column(self.dialect.compile_auto_increment_pk(col_quoted)) + else: + self._append_column(f"{col_quoted} INT AUTO_INCREMENT") return self def integer(self, column_name: str, length: int | None = None) -> TableBuilder: diff --git a/shiba/dialects/base.py b/shiba/dialects/base.py index 1cece2f..eca1ffa 100644 --- a/shiba/dialects/base.py +++ b/shiba/dialects/base.py @@ -43,9 +43,23 @@ def render_limit(self, limit: int | None, offset: int | None) -> str: return " ".join(parts) @abstractmethod - def compile_upsert_update(self, update_columns: list[str]) -> str: + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: """Cláusula de resolución de conflicto para ``upsert``. MySQL → ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` - Postgres/SQLite → ``ON CONFLICT (...) DO UPDATE SET ...``. + (``conflict_columns`` se ignora; lo detecta por la PK). + Postgres/SQLite → ``ON CONFLICT (col, ...) DO UPDATE SET col = EXCLUDED.col`` + (``conflict_columns`` obligatorio). """ + + def compile_auto_increment_pk(self, column_quoted: str) -> str: + """Declaración inline de PK auto-incremental. + + Default MySQL-ish. Postgres lo override con + ``GENERATED ALWAYS AS IDENTITY``. + """ + return f"{column_quoted} INT AUTO_INCREMENT PRIMARY KEY" diff --git a/shiba/dialects/mysql/dialect.py b/shiba/dialects/mysql/dialect.py index 91e237d..595f40e 100644 --- a/shiba/dialects/mysql/dialect.py +++ b/shiba/dialects/mysql/dialect.py @@ -18,7 +18,11 @@ def quote_identifier(self, name: str) -> str: def map_type(self, declared: str) -> str: return _map_type(declared) - def compile_upsert_update(self, update_columns: list[str]) -> str: + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: if not update_columns: return "" parts = [f"{_qi(c)} = VALUES({_qi(c)})" for c in update_columns] diff --git a/shiba/dialects/postgres/__init__.py b/shiba/dialects/postgres/__init__.py new file mode 100644 index 0000000..8507025 --- /dev/null +++ b/shiba/dialects/postgres/__init__.py @@ -0,0 +1,19 @@ +"""Dialecto PostgreSQL.""" +from typing import Any + +from shiba.dialects.postgres.dialect import PostgresDialect + +__all__ = ["PostgresDialect"] + + +def _import_driver() -> Any: + """Import perezoso de ``Database`` (requiere ``psycopg``).""" + from shiba.dialects.postgres.driver import Database + + return Database + + +def __getattr__(name: str) -> Any: + if name == "Database": + return _import_driver() + raise AttributeError(name) diff --git a/shiba/dialects/postgres/dialect.py b/shiba/dialects/postgres/dialect.py new file mode 100644 index 0000000..460fa21 --- /dev/null +++ b/shiba/dialects/postgres/dialect.py @@ -0,0 +1,39 @@ +"""Implementación :class:`~shiba.dialects.base.Dialect` para PostgreSQL.""" +from __future__ import annotations + +from shiba import error_codes +from shiba.dialects.base import Dialect +from shiba.dialects.postgres.quoting import quote_identifier as _qi +from shiba.dialects.postgres.schema import map_type as _map_type + + +class PostgresDialect(Dialect): + """Doble comillas, placeholder ``%s`` (psycopg), ``ON CONFLICT``.""" + + name = "postgres" + placeholder = "%s" + + def quote_identifier(self, name: str) -> str: + return _qi(name) + + def map_type(self, declared: str) -> str: + return _map_type(declared) + + def compile_auto_increment_pk(self, column_quoted: str) -> str: + return f"{column_quoted} INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY" + + def compile_upsert_update( + self, + update_columns: list[str], + conflict_columns: list[str] | None = None, + ) -> str: + if not conflict_columns: + error_codes.MISSING_REQUIRED_DATA.raise_( + "upsert() en Postgres requiere el parámetro `on=[col, ...]` " + "para identificar el conflicto." + ) + target = ", ".join(_qi(c) for c in conflict_columns) + if not update_columns: + return f"ON CONFLICT ({target}) DO NOTHING" + sets = ", ".join(f"{_qi(c)} = EXCLUDED.{_qi(c)}" for c in update_columns) + return f"ON CONFLICT ({target}) DO UPDATE SET {sets}" diff --git a/shiba/dialects/postgres/driver.py b/shiba/dialects/postgres/driver.py new file mode 100644 index 0000000..c710106 --- /dev/null +++ b/shiba/dialects/postgres/driver.py @@ -0,0 +1,238 @@ +"""Wrapper sobre :mod:`psycopg` (v3) con la interfaz Shiba ``Database``.""" +from __future__ import annotations + +import logging +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +from shiba import error_codes +from shiba.dialects.postgres.quoting import quote_identifier +from shiba.error_codes import from_driver_exception + +if TYPE_CHECKING: + from types import TracebackType + +try: + import psycopg + from psycopg.rows import dict_row +except ImportError: # pragma: no cover - dep opcional + psycopg = None # type: ignore[assignment] + dict_row = None # type: ignore[assignment] + + +def _require_psycopg() -> None: + if psycopg is None: + raise ImportError( + "El dialecto Postgres requiere `psycopg[binary]`. " + "Instala: pip install 'shiba_mysql[postgres]' o psycopg[binary]" + ) + + +logger = logging.getLogger("shiba.postgres") + + +class Database: + """Conexión Postgres con la misma API que :class:`shiba.Database` de MySQL.""" + + def __init__( + self, + host: str, + port: int, + user: str, + password: str, + *, + database: str | None = None, + autoconnect: bool = True, + ) -> None: + _require_psycopg() + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self._connection: Any = None + self._in_transaction: bool = False + if autoconnect: + self.connect() + + # ------------------------------------------------------------------ + # Ciclo de vida + # ------------------------------------------------------------------ + + def connect(self) -> None: + if self._connection is not None and not self._connection.closed: + return + try: + self._connection = psycopg.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + dbname=self.database, + autocommit=True, + row_factory=dict_row, + ) + except psycopg.OperationalError as exc: + code = from_driver_exception(exc) + raise code.build( + f"No se pudo conectar a {self.host}:{self.port}: {exc}", + details={"host": self.host, "port": self.port}, + ) from exc + + def close(self) -> None: + if self._connection is not None and not self._connection.closed: + self._connection.close() + self._connection = None + self._in_transaction = False + + def __enter__(self) -> Database: + self.connect() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + @property + def _conn(self) -> Any: + if self._connection is None or self._connection.closed: + raise error_codes.CONNECTION_NOT_OPEN.build( + "La conexión Postgres no está abierta." + ) + return self._connection + + # ------------------------------------------------------------------ + # Ejecución + # ------------------------------------------------------------------ + + def execute( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + if not query or not isinstance(query, str): + raise error_codes.EMPTY_QUERY.build( + "Se intentó ejecutar una query vacía.", + query=str(query), + ) + conn = self._conn + try: + with conn.cursor() as cursor: + if params is None: + cursor.execute(query) + elif many: + cursor.executemany(query, params) + else: + cursor.execute(query, params) + try: + rows: list[dict[str, Any]] = list(cursor.fetchall()) + except psycopg.ProgrammingError: + rows = [] + if not self._in_transaction: + conn.commit() + return rows + except psycopg.errors.IntegrityError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Violación de integridad: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.errors.SyntaxError as exc: + self._rollback_silent() + raise error_codes.QUERY_SYNTAX_ERROR.build( + f"Error de sintaxis: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.OperationalError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error operacional: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except psycopg.Error as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error de driver: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + + execute_query = execute + + def raw( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + return self.execute(query, params, many=many) + + def _rollback_silent(self) -> None: + if self._connection is None or self._connection.closed: + return + try: + self._connection.rollback() + except psycopg.Error: # pragma: no cover - best effort + logger.warning("rollback failed", exc_info=True) + + # ------------------------------------------------------------------ + # Transacciones + # ------------------------------------------------------------------ + + @contextmanager + def transaction(self) -> Iterator[Database]: + if self._in_transaction: + raise error_codes.TRANSACTION_ALREADY_ACTIVE.build() + conn = self._conn + # En psycopg3 autocommit=True implica BEGIN explícito para abrir tx. + conn.execute("BEGIN") + self._in_transaction = True + try: + yield self + except BaseException: + self._rollback_silent() + raise + else: + conn.commit() + finally: + self._in_transaction = False + + # ------------------------------------------------------------------ + # DDL conveniencia + # ------------------------------------------------------------------ + + def create_database(self, name: str) -> Database: + try: + self.execute(f"CREATE DATABASE {quote_identifier(name)}") + except Exception as exc: # pragma: no cover - mensajes varían + if "already exists" not in str(exc): + raise + self.database = name + return self + + def use_database(self, name: str) -> Database: + """En Postgres no hay ``USE``. Se cierra y reconecta a la nueva DB.""" + self.close() + self.database = name + self.connect() + return self + + selected_database = use_database diff --git a/shiba/dialects/postgres/quoting.py b/shiba/dialects/postgres/quoting.py new file mode 100644 index 0000000..960e17b --- /dev/null +++ b/shiba/dialects/postgres/quoting.py @@ -0,0 +1,13 @@ +"""Quoting de identificadores Postgres (comillas dobles).""" +from __future__ import annotations + +from shiba.identifiers import validate_identifier + + +def quote_identifier(name: str) -> str: + """Valida y cita el identificador con comillas dobles. + + Soporta ``schema.table``, citando cada parte. + """ + validate_identifier(name) + return ".".join(f'"{part}"' for part in name.split(".")) diff --git a/shiba/dialects/postgres/schema.py b/shiba/dialects/postgres/schema.py new file mode 100644 index 0000000..4762b7a --- /dev/null +++ b/shiba/dialects/postgres/schema.py @@ -0,0 +1,26 @@ +"""Mapeo de tipos canónicos Shiba a SQL Postgres. + +El ``TableBuilder`` emite sintaxis estilo MySQL (``DATETIME``, +``BOOLEAN``, ``JSON``, ``BLOB``, etc.). Este módulo los traduce a la +forma idiomática de Postgres. +""" +from __future__ import annotations + +_MAPPING: dict[str, str] = { + "DATETIME": "TIMESTAMP", + "BLOB": "BYTEA", + "JSON": "JSONB", + "DOUBLE": "DOUBLE PRECISION", +} + + +def map_type(declared: str) -> str: + upper = declared.upper() + # Tipos con parámetros: ``VARCHAR(50)``, ``DECIMAL(10,2)``, etc. + if "(" in upper: + base, params = upper.split("(", 1) + base = base.strip() + if base in _MAPPING: + return f"{_MAPPING[base]}({params}" + return declared + return _MAPPING.get(upper, declared) diff --git a/tests/test_connect_factory.py b/tests/test_connect_factory.py new file mode 100644 index 0000000..309a7c1 --- /dev/null +++ b/tests/test_connect_factory.py @@ -0,0 +1,81 @@ +"""`shiba.connect(dsn)` resuelve el dialecto desde la URL.""" +from __future__ import annotations + +import pytest + +import shiba +from shiba import error_codes +from shiba.errors import ShibaError + + +def test_unsupported_scheme_raises() -> None: + with pytest.raises(ShibaError) as ei: + shiba.connect("oracle://u:p@h/db") + assert ei.value.code is error_codes.NOT_IMPLEMENTED + + +def test_mysql_dsn_picks_mysql_dialect(monkeypatch) -> None: + """Sin levantar conexión real, verificamos que el factory escogería MySQL.""" + constructed: dict[str, object] = {} + + class FakeMySQLDb: + def __init__(self, host, port, user, password, *, database=None): + constructed.update( + kind="mysql", + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + monkeypatch.setattr(shiba, "Database", FakeMySQLDb) + cx = shiba.connect("mysql://alice:secret@db.host:3307/app") + assert constructed["kind"] == "mysql" + assert constructed["host"] == "db.host" + assert constructed["port"] == 3307 + assert constructed["user"] == "alice" + assert constructed["database"] == "app" + assert cx.dialect.name == "mysql" + + +def test_postgres_dsn_picks_postgres_dialect(monkeypatch) -> None: + constructed: dict[str, object] = {} + + class FakePgDb: + def __init__(self, host, port, user, password, *, database=None): + constructed.update( + kind="postgres", + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + # Reemplazamos el import perezoso del driver Postgres por el fake. + import shiba.dialects.postgres.driver as pg_driver + + monkeypatch.setattr(pg_driver, "Database", FakePgDb) + cx = shiba.connect("postgres://bob:hunter2@pg.host/app") + assert constructed["kind"] == "postgres" + assert constructed["host"] == "pg.host" + assert constructed["port"] == 5432 # default + assert constructed["user"] == "bob" + assert constructed["database"] == "app" + assert cx.dialect.name == "postgres" + + +def test_postgresql_alias() -> None: + """``postgresql://`` y ``postgres://`` deberían comportarse igual.""" + import shiba.dialects.postgres.driver as pg_driver + + calls = [] + + class FakePgDb: + def __init__(self, *a, **kw): + calls.append((a, kw)) + + pg_driver.Database = FakePgDb # type: ignore[misc] + cx = shiba.connect("postgresql://x:y@h/d") + assert cx.dialect.name == "postgres" diff --git a/tests/test_postgres_dialect.py b/tests/test_postgres_dialect.py new file mode 100644 index 0000000..010eb7d --- /dev/null +++ b/tests/test_postgres_dialect.py @@ -0,0 +1,92 @@ +"""PostgresDialect emite SQL en la sintaxis idiomática del motor.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.core.table_builder import TableBuilder +from shiba.dialects.postgres.dialect import PostgresDialect +from shiba.errors import MissingDataError, SchemaError + + +@pytest.fixture +def pg() -> PostgresDialect: + return PostgresDialect() + + +def test_quotes_with_double_quotes(pg) -> None: + assert pg.quote_identifier("users") == '"users"' + assert pg.quote_identifier("schema.table") == '"schema"."table"' + + +def test_quote_rejects_injection(pg) -> None: + with pytest.raises(SchemaError) as ei: + pg.quote_identifier('users"; DROP TABLE x; --') + assert ei.value.code is error_codes.INVALID_IDENTIFIER + + +def test_select_uses_double_quotes(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).where("name", "John").get() + sql, params, _ = fake_db.last_call + assert sql.startswith('SELECT * FROM "users"') + assert 'WHERE "name" = %s' in sql + assert params == ("John",) + + +def test_create_table_uses_identity_pk(fake_db, pg) -> None: + sql = ( + TableBuilder(fake_db, "users", dialect=pg) + .increments("id", primary_key=True) + .string("name") + .json("settings") + .to_sql() + ) + assert 'CREATE TABLE IF NOT EXISTS "users"' in sql + assert '"id" INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY' in sql + assert '"name" VARCHAR(255)' in sql + assert '"settings" JSONB' in sql # JSON → JSONB + + +def test_datetime_maps_to_timestamp(fake_db, pg) -> None: + sql = TableBuilder(fake_db, "t", dialect=pg).datetime("created_at").to_sql() + assert '"created_at" TIMESTAMP' in sql + assert "DATETIME" not in sql + + +def test_blob_maps_to_bytea(fake_db, pg) -> None: + sql = TableBuilder(fake_db, "t", dialect=pg).binary("data").to_sql() + assert '"data" BYTEA' in sql + + +def test_upsert_emits_on_conflict(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).upsert( + {"id": 1, "name": "X"}, on=["id"] + ) + sql, params, _ = fake_db.last_call + assert sql == ( + 'INSERT INTO "users" ("id", "name") VALUES (%s, %s) ' + 'ON CONFLICT ("id") DO UPDATE SET "id" = EXCLUDED."id", "name" = EXCLUDED."name"' + ) + assert params == (1, "X") + + +def test_upsert_without_on_in_postgres_raises(fake_db, pg) -> None: + with pytest.raises(MissingDataError) as ei: + QueryBuilder(fake_db, "users", dialect=pg).upsert({"id": 1, "name": "X"}) + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +def test_upsert_empty_update_does_nothing(fake_db, pg) -> None: + QueryBuilder(fake_db, "users", dialect=pg).upsert( + {"id": 1, "name": "X"}, update=[], on=["id"] + ) + sql, _, _ = fake_db.last_call + assert 'ON CONFLICT ("id") DO NOTHING' in sql + + +def test_mysql_upsert_still_works_without_on(fake_db, dialect) -> None: + """`on` es opcional en MySQL (lo detecta la PK).""" + QueryBuilder(fake_db, "users", dialect=dialect).upsert({"id": 1, "name": "X"}) + sql, _, _ = fake_db.last_call + assert "ON DUPLICATE KEY UPDATE" in sql