Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ classifiers = [
dependencies = ["pymysql>=1.1"]

[project.optional-dependencies]
postgres = ["psycopg[binary]>=3.1"]
dev = [
"pytest>=8",
"pytest-cov>=5",
"ruff>=0.6",
"mypy>=1.10",
"psycopg[binary]>=3.1",
]

[project.urls]
Expand Down
110 changes: 84 additions & 26 deletions shiba/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
"""Shiba — librería ligera para hablar con bases de datos relacionales.

Punto de entrada público:

.. code-block:: python

import shiba as s
import shiba

with s.ShibaConnection(host="localhost", port=3306,
user="u", password="p") as cx:
cx.create_database("my_db")
cx.use_database("my_db")
cx.create_table("users") \\
.increments("id", primary_key=True) \\
.string("name") \\
.build()
# Forma 1 — DSN explícito (recomendado para multi-dialecto):
cx = shiba.connect("mysql://user:pass@localhost:3306/my_db")
cx = shiba.connect("postgres://user:pass@localhost:5432/my_db")

cx.table("users").insert({"name": "John"})
rows = cx.table("users").where("name", "John").get()
# Forma 2 — construcción directa MySQL (legacy):
cx = shiba.ShibaConnection(host="localhost", port=3306,
user="u", password="p")
"""
from __future__ import annotations

from contextlib import AbstractContextManager
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from shiba import error_codes
from shiba.core.query_builder import QueryBuilder
from shiba.core.table_builder import TableBuilder
from shiba.dialects.base import Dialect
from shiba.dialects.mysql import Database, MySQLDialect
from shiba.errors import (
ConnectionError,
Expand All @@ -42,17 +38,36 @@


class ShibaConnection:
"""Fachada de alto nivel sobre un :class:`Database` MySQL."""
"""Fachada de alto nivel agnóstica de dialecto.

Acepta dos formas de construcción:

* Legacy MySQL: ``ShibaConnection(host, port, user, password)``.
* Inyectada: ``ShibaConnection(db=Database(...), dialect=Dialect(...))``
(la usa :func:`connect`).
"""

def __init__(
self,
host: str,
port: int,
user: str,
password: str,
host: str | None = None,
port: int | None = None,
user: str | None = None,
password: str | None = None,
*,
database: str | None = None,
db: Any = None,
dialect: Dialect | None = None,
) -> None:
if db is not None and dialect is not None:
self.dialect: Dialect = dialect
self.db: Any = db
return

if host is None or port is None or user is None or password is None:
error_codes.MISSING_REQUIRED_DATA.raise_(
"ShibaConnection requiere host/port/user/password o "
"db+dialect inyectados."
)
self.dialect = MySQLDialect()
self.db = Database(host, port, user, password, database=database)

Expand All @@ -78,10 +93,10 @@ def close(self) -> None:
# API pública
# ------------------------------------------------------------------

def create_database(self, database: str) -> Database:
def create_database(self, database: str) -> Any:
return self.db.create_database(database)

def use_database(self, database: str) -> Database:
def use_database(self, database: str) -> Any:
return self.db.use_database(database)

def create_table(self, table_name: str) -> TableBuilder:
Expand All @@ -90,9 +105,10 @@ def create_table(self, table_name: str) -> TableBuilder:
def table(self, table_name: str) -> QueryBuilder:
return QueryBuilder(self.db, table_name, dialect=self.dialect)

def transaction(self) -> AbstractContextManager[Database]:
"""Context manager transaccional. Ver :meth:`Database.transaction`."""
return self.db.transaction()
def transaction(self) -> AbstractContextManager[Any]:
"""Context manager transaccional."""
cm: AbstractContextManager[Any] = self.db.transaction()
return cm

def raw(
self,
Expand All @@ -101,13 +117,54 @@ def raw(
*,
many: bool = False,
) -> list[dict[str, object]]:
"""Escape hatch — ver :meth:`Database.raw`."""
return self.db.raw(query, params, many=many)
rows: list[dict[str, object]] = self.db.raw(query, params, many=many)
return rows


# ---------------------------------------------------------------------------
# Factory connect(dsn)
# ---------------------------------------------------------------------------


_DEFAULT_PORTS = {"mysql": 3306, "postgres": 5432, "postgresql": 5432}


def connect(dsn: str) -> ShibaConnection:
"""Construye una :class:`ShibaConnection` desde un DSN tipo URL.

Schemes soportados:

* ``mysql://user:pass@host:port/dbname``
* ``postgres://user:pass@host:port/dbname`` (alias: ``postgresql://``)
"""
parsed = urlparse(dsn)
scheme = parsed.scheme.lower()
if scheme not in _DEFAULT_PORTS:
error_codes.NOT_IMPLEMENTED.raise_(
f"DSN scheme '{scheme}' no soportado. Usa: {sorted(_DEFAULT_PORTS)}."
)
host = parsed.hostname or "localhost"
port = parsed.port or _DEFAULT_PORTS[scheme]
user = parsed.username or ""
password = parsed.password or ""
database = parsed.path.lstrip("/") or None

if scheme == "mysql":
db: Any = Database(host, port, user, password, database=database)
return ShibaConnection(db=db, dialect=MySQLDialect())

# postgres / postgresql
from shiba.dialects.postgres import PostgresDialect
from shiba.dialects.postgres.driver import Database as PgDatabase

db = PgDatabase(host, port, user, password, database=database)
return ShibaConnection(db=db, dialect=PostgresDialect())


__all__ = [
"ConnectionError",
"Database",
"Dialect",
"IntegrityError",
"MissingDataError",
"Model",
Expand All @@ -118,6 +175,7 @@ def raw(
"ShibaConnection",
"ShibaError",
"TableBuilder",
"connect",
"error_codes",
"fields",
"set_default_connection",
Expand Down
12 changes: 8 additions & 4 deletions shiba/core/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,14 @@ def upsert(
data: dict[str, Any],
*,
update: list[str] | None = None,
on: list[str] | None = None,
) -> list[dict[str, Any]]:
"""INSERT con resolución de conflicto.

En MySQL se emite ``ON DUPLICATE KEY UPDATE``. El parámetro
``update`` indica qué columnas pisar (por defecto todas excepto
las que sean clave). El dialecto adapta la cláusula.
:param data: columnas → valores.
:param update: columnas a pisar en conflicto (default todas).
:param on: columnas del conflicto. Requerido por Postgres,
opcional en MySQL (lo detecta por la PK).
"""
if not data:
raise error_codes.MISSING_REQUIRED_DATA.build("upsert(): dict vacío.")
Expand All @@ -515,7 +517,9 @@ def upsert(
update_cols = update if update is not None else list(data.keys())
for col in update_cols:
validate_identifier(col, kind="column")
update_sql = self.dialect.compile_upsert_update(update_cols)
for col in on or []:
validate_identifier(col, kind="column")
update_sql = self.dialect.compile_upsert_update(update_cols, on)
table = self.dialect.quote_identifier(self.table_name)
query = (
f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders}) {update_sql}"
Expand Down
9 changes: 5 additions & 4 deletions shiba/core/table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def _col(self, column_name: str, sql_type: str) -> TableBuilder:

def increments(self, column_name: str = "id", primary_key: bool = False) -> TableBuilder:
validate_identifier(column_name, kind="column")
suffix = " PRIMARY KEY" if primary_key else ""
self._append_column(
f"{self.dialect.quote_identifier(column_name)} INT AUTO_INCREMENT{suffix}"
)
col_quoted = self.dialect.quote_identifier(column_name)
if primary_key:
self._append_column(self.dialect.compile_auto_increment_pk(col_quoted))
else:
self._append_column(f"{col_quoted} INT AUTO_INCREMENT")
return self

def integer(self, column_name: str, length: int | None = None) -> TableBuilder:
Expand Down
18 changes: 16 additions & 2 deletions shiba/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,23 @@ def render_limit(self, limit: int | None, offset: int | None) -> str:
return " ".join(parts)

@abstractmethod
def compile_upsert_update(self, update_columns: list[str]) -> str:
def compile_upsert_update(
self,
update_columns: list[str],
conflict_columns: list[str] | None = None,
) -> str:
"""Cláusula de resolución de conflicto para ``upsert``.

MySQL → ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...``
Postgres/SQLite → ``ON CONFLICT (...) DO UPDATE SET ...``.
(``conflict_columns`` se ignora; lo detecta por la PK).
Postgres/SQLite → ``ON CONFLICT (col, ...) DO UPDATE SET col = EXCLUDED.col``
(``conflict_columns`` obligatorio).
"""

def compile_auto_increment_pk(self, column_quoted: str) -> str:
"""Declaración inline de PK auto-incremental.

Default MySQL-ish. Postgres lo override con
``GENERATED ALWAYS AS IDENTITY``.
"""
return f"{column_quoted} INT AUTO_INCREMENT PRIMARY KEY"
6 changes: 5 additions & 1 deletion shiba/dialects/mysql/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def quote_identifier(self, name: str) -> str:
def map_type(self, declared: str) -> str:
return _map_type(declared)

def compile_upsert_update(self, update_columns: list[str]) -> str:
def compile_upsert_update(
self,
update_columns: list[str],
conflict_columns: list[str] | None = None,
) -> str:
if not update_columns:
return ""
parts = [f"{_qi(c)} = VALUES({_qi(c)})" for c in update_columns]
Expand Down
19 changes: 19 additions & 0 deletions shiba/dialects/postgres/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Dialecto PostgreSQL."""
from typing import Any

from shiba.dialects.postgres.dialect import PostgresDialect

__all__ = ["PostgresDialect"]


def _import_driver() -> Any:
"""Import perezoso de ``Database`` (requiere ``psycopg``)."""
from shiba.dialects.postgres.driver import Database

return Database


def __getattr__(name: str) -> Any:
if name == "Database":
return _import_driver()
raise AttributeError(name)
39 changes: 39 additions & 0 deletions shiba/dialects/postgres/dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Implementación :class:`~shiba.dialects.base.Dialect` para PostgreSQL."""
from __future__ import annotations

from shiba import error_codes
from shiba.dialects.base import Dialect
from shiba.dialects.postgres.quoting import quote_identifier as _qi
from shiba.dialects.postgres.schema import map_type as _map_type


class PostgresDialect(Dialect):
"""Doble comillas, placeholder ``%s`` (psycopg), ``ON CONFLICT``."""

name = "postgres"
placeholder = "%s"

def quote_identifier(self, name: str) -> str:
return _qi(name)

def map_type(self, declared: str) -> str:
return _map_type(declared)

def compile_auto_increment_pk(self, column_quoted: str) -> str:
return f"{column_quoted} INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY"

def compile_upsert_update(
self,
update_columns: list[str],
conflict_columns: list[str] | None = None,
) -> str:
if not conflict_columns:
error_codes.MISSING_REQUIRED_DATA.raise_(
"upsert() en Postgres requiere el parámetro `on=[col, ...]` "
"para identificar el conflicto."
)
target = ", ".join(_qi(c) for c in conflict_columns)
if not update_columns:
return f"ON CONFLICT ({target}) DO NOTHING"
sets = ", ".join(f"{_qi(c)} = EXCLUDED.{_qi(c)}" for c in update_columns)
return f"ON CONFLICT ({target}) DO UPDATE SET {sets}"
Loading
Loading