diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0b265f4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,41 @@ +name: ci + +on: + push: + branches: [main] + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint-type-test: + name: ${{ matrix.python-version }} · lint · type · test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Ruff + run: ruff check . + + - name: Mypy + run: mypy + + - name: Pytest + run: pytest -q diff --git a/MANINEFT.in b/MANIFEST.in similarity index 100% rename from MANINEFT.in rename to MANIFEST.in diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..358e2ce --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,71 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "shiba_mysql" +version = "2.0.0" +description = "Cliente y query builder ligero para bases de datos relacionales (MySQL hoy, multi-dialecto en camino)." +readme = "README.md" +requires-python = ">=3.10" +license = { text = "MIT" } +authors = [{ name = "Rodrigo Pino", email = "ro.pinoo18@gmail.com" }] +keywords = ["mysql", "database", "sql", "query-builder", "orm"] +classifiers = [ + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Database", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] +dependencies = ["pymysql>=1.1"] + +[project.optional-dependencies] +dev = [ + "pytest>=8", + "pytest-cov>=5", + "ruff>=0.6", + "mypy>=1.10", +] + +[project.urls] +Homepage = "https://github.com/ShibaRoPinoo/Shiba-Py-Mysql" +Issues = "https://github.com/ShibaRoPinoo/Shiba-Py-Mysql/issues" + +[tool.setuptools.packages.find] +include = ["shiba*", "shibamysql*"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "B", "UP", "SIM", "RUF", "ANN", "A"] +ignore = [ + "ANN401", # permitir Any explícito + "A004", # sombreamos ConnectionError a propósito (parte de la API pública) + "UP038", # `isinstance(x, X | Y)` rompe en py3.10 sin from __future__ import +] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["ANN", "D"] + +[tool.mypy] +python_version = "3.10" +strict = true +warn_unreachable = true +show_error_codes = true +files = ["shiba", "shibamysql"] + +[[tool.mypy.overrides]] +module = ["pymysql.*"] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-ra --strict-markers" diff --git a/setup.py b/setup.py deleted file mode 100644 index fa52923..0000000 --- a/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -from setuptools import setup, find_packages - -classifiers = [ - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Database', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', -] - -setup( - name='shiba_mysql', - version='1.2.0', - description='A library to interact with MySQL', - long_description=open('README.md').read() + '\n\n' + open('CHANGELOG.txt').read(), - long_description_content_type='text/markdown', - author='Rodrigo Pino', - license='MIT', - classifiers=classifiers, - keywords=['mysql', 'database', 'SQL', 'data access', 'ORM'], - author_email='ro.pinoo18@gmail.com', - url='https://github.com/ShibaRoPinoo/Shiba-Py-Mysql', - packages=find_packages(), - install_requires=['pymysql'], -) diff --git a/shiba/__init__.py b/shiba/__init__.py new file mode 100644 index 0000000..67abc73 --- /dev/null +++ b/shiba/__init__.py @@ -0,0 +1,110 @@ +"""Shiba — librería ligera para hablar con bases de datos relacionales. + +Punto de entrada público: + +.. code-block:: python + + import shiba as s + + with s.ShibaConnection(host="localhost", port=3306, + user="u", password="p") as cx: + cx.create_database("my_db") + cx.use_database("my_db") + cx.create_table("users") \\ + .increments("id", primary_key=True) \\ + .string("name") \\ + .build() + + cx.table("users").insert({"name": "John"}) + rows = cx.table("users").where("name", "John").get() +""" +from __future__ import annotations + +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.core.table_builder import TableBuilder +from shiba.dialects.mysql import Database, MySQLDialect +from shiba.errors import ( + ConnectionError, + IntegrityError, + MissingDataError, + QueryError, + SchemaError, + ShibaError, +) + +if TYPE_CHECKING: + from types import TracebackType + + +class ShibaConnection: + """Fachada de alto nivel sobre un :class:`Database` MySQL.""" + + def __init__( + self, + host: str, + port: int, + user: str, + password: str, + *, + database: str | None = None, + ) -> None: + self.dialect = MySQLDialect() + self.db = Database(host, port, user, password, database=database) + + # ------------------------------------------------------------------ + # Context manager + # ------------------------------------------------------------------ + + def __enter__(self) -> ShibaConnection: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + self.db.close() + + # ------------------------------------------------------------------ + # API pública + # ------------------------------------------------------------------ + + def create_database(self, database: str) -> Database: + return self.db.create_database(database) + + def use_database(self, database: str) -> Database: + return self.db.use_database(database) + + def create_table(self, table_name: str) -> TableBuilder: + return TableBuilder(self.db, table_name, dialect=self.dialect) + + def table(self, table_name: str) -> QueryBuilder: + return QueryBuilder(self.db, table_name, dialect=self.dialect) + + def transaction(self) -> AbstractContextManager[Database]: + """Context manager transaccional. Ver :meth:`Database.transaction`.""" + return self.db.transaction() + + +__all__ = [ + "ConnectionError", + "Database", + "IntegrityError", + "MissingDataError", + "MySQLDialect", + "QueryBuilder", + "QueryError", + "SchemaError", + "ShibaConnection", + "ShibaError", + "TableBuilder", + "error_codes", +] diff --git a/shiba/core/__init__.py b/shiba/core/__init__.py new file mode 100644 index 0000000..32acdbd --- /dev/null +++ b/shiba/core/__init__.py @@ -0,0 +1 @@ +"""Núcleo agnóstico de Shiba (no depende de ningún driver concreto).""" diff --git a/shiba/core/query_builder.py b/shiba/core/query_builder.py new file mode 100644 index 0000000..9a6eeb4 --- /dev/null +++ b/shiba/core/query_builder.py @@ -0,0 +1,293 @@ +"""QueryBuilder fluido y agnóstico de dialecto. + +Construye SQL con **placeholders parametrizados** para todos los valores +y delega quoting de identificadores al :class:`~shiba.dialects.base.Dialect`. +Esto cierra el agujero de SQL injection que tenía la v1.x. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from shiba import error_codes +from shiba.identifiers import validate_identifier, validate_operator + +if TYPE_CHECKING: + from shiba.dialects.base import Dialect + from shiba.dialects.mysql.driver import Database + + +_VALID_JOIN_TYPES = frozenset({"JOIN", "LEFT JOIN", "RIGHT JOIN", "INNER JOIN", "CROSS JOIN"}) + + +class QueryBuilder: + """API fluida para consultas. Inmutable-ish: cada método retorna ``self``.""" + + def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: + self.db = db + self.dialect = dialect + self.table_name = validate_identifier(table_name, kind="table") + self._selected: list[str] = [] + self._joins: list[str] = [] + self._where: list[tuple[str, str, Any]] = [] + self._order_by: list[tuple[str, str]] = [] + self._limit: int | None = None + self._offset: int | None = None + + # ------------------------------------------------------------------ + # SELECT + # ------------------------------------------------------------------ + + def select(self, *columns: str) -> QueryBuilder: + if not columns: + raise error_codes.MISSING_REQUIRED_DATA.build( + "select() requiere al menos una columna." + ) + for col in columns: + validate_identifier(col, kind="column") + self._selected.extend(columns) + return self + + # ------------------------------------------------------------------ + # JOIN + # ------------------------------------------------------------------ + + def _add_join( + self, + kind: str, + table_name: str, + column1: str, + operator: str, + column2: str, + ) -> QueryBuilder: + if kind not in _VALID_JOIN_TYPES: + raise error_codes.NOT_IMPLEMENTED.build(f"JOIN type no soportado: {kind}") + if not all([table_name, column1, operator, column2]): + raise error_codes.MISSING_REQUIRED_DATA.build( + f"Faltan parámetros requeridos para {kind}." + ) + t = self.dialect.quote_identifier(table_name) + c1 = self.dialect.quote_identifier(column1) + c2 = self.dialect.quote_identifier(column2) + op = validate_operator(operator) + self._joins.append(f"{kind} {t} ON {c1} {op} {c2}") + return self + + def join(self, table_name: str, column1: str, operator: str, column2: str) -> QueryBuilder: + return self._add_join("JOIN", table_name, column1, operator, column2) + + def inner_join( + self, table_name: str, column1: str, operator: str, column2: str + ) -> QueryBuilder: + return self._add_join("INNER JOIN", table_name, column1, operator, column2) + + def left_join( + self, table_name: str, column1: str, operator: str, column2: str + ) -> QueryBuilder: + return self._add_join("LEFT JOIN", table_name, column1, operator, column2) + + def right_join( + self, table_name: str, column1: str, operator: str, column2: str + ) -> QueryBuilder: + return self._add_join("RIGHT JOIN", table_name, column1, operator, column2) + + def cross_join(self, table_name: str) -> QueryBuilder: + if not table_name: + raise error_codes.MISSING_REQUIRED_DATA.build("CROSS JOIN requiere tabla.") + self._joins.append(f"CROSS JOIN {self.dialect.quote_identifier(table_name)}") + return self + + # ------------------------------------------------------------------ + # WHERE + # ------------------------------------------------------------------ + + def where(self, *args: Any) -> QueryBuilder: + """``where(col, val)`` o ``where(col, op, val)`` o ``where([[...], [...]])``.""" + # Forma con lista de condiciones. + if len(args) == 1 and isinstance(args[0], list): + return self._where_many(args[0]) + + if len(args) == 2: + column, value = args + operator = "=" + elif len(args) == 3: + column, operator, value = args + else: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"where() acepta 2 o 3 argumentos, recibió {len(args)}." + ) + + validate_identifier(column, kind="column") + op = validate_operator(operator) + self._where.append((column, op, value)) + return self + + def _where_many(self, conditions: list[Any]) -> QueryBuilder: + for cond in conditions: + if not isinstance(cond, (list, tuple)): + raise error_codes.INVALID_QUERY_PARAMS.build( + f"cada condición de where() debe ser list/tuple, recibió {type(cond).__name__}." + ) + if len(cond) == 2: + self.where(cond[0], cond[1]) + elif len(cond) == 3: + self.where(cond[0], cond[1], cond[2]) + else: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"condición con {len(cond)} elementos; se esperaban 2 o 3." + ) + return self + + # ------------------------------------------------------------------ + # ORDER / LIMIT + # ------------------------------------------------------------------ + + def order_by(self, column: str, direction: str = "ASC") -> QueryBuilder: + validate_identifier(column, kind="column") + d = direction.strip().upper() + if d not in {"ASC", "DESC"}: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"dirección inválida: {direction!r}. Use ASC o DESC." + ) + self._order_by.append((column, d)) + return self + + def limit(self, n: int) -> QueryBuilder: + self._limit = int(n) + return self + + def offset(self, n: int) -> QueryBuilder: + self._offset = int(n) + return self + + # ------------------------------------------------------------------ + # Compilación de WHERE + # ------------------------------------------------------------------ + + def _compile_where(self) -> tuple[str, list[Any]]: + if not self._where: + return "", [] + parts: list[str] = [] + params: list[Any] = [] + for column, op, value in self._where: + col_sql = self.dialect.quote_identifier(column) + if op in {"IN", "NOT IN"}: + if not isinstance(value, (list, tuple)) or not value: + raise error_codes.INVALID_QUERY_PARAMS.build( + f"{op} requiere lista/tupla no vacía." + ) + placeholders = ", ".join([self.dialect.placeholder] * len(value)) + parts.append(f"{col_sql} {op} ({placeholders})") + params.extend(value) + elif op in {"IS", "IS NOT"} and value is None: + parts.append(f"{col_sql} {op} NULL") + else: + parts.append(f"{col_sql} {op} {self.dialect.placeholder}") + params.append(value) + return "WHERE " + " AND ".join(parts), params + + # ------------------------------------------------------------------ + # SELECT execution + # ------------------------------------------------------------------ + + def get(self) -> list[dict[str, Any]]: + select_clause = "*" + if self._selected: + select_clause = ", ".join(self.dialect.quote_identifier(c) for c in self._selected) + + joins = " ".join(self._joins) + where_sql, params = self._compile_where() + + order_sql = "" + if self._order_by: + order_sql = "ORDER BY " + ", ".join( + f"{self.dialect.quote_identifier(c)} {d}" for c, d in self._order_by + ) + limit_sql = self.dialect.render_limit(self._limit, self._offset) + + table = self.dialect.quote_identifier(self.table_name) + parts = [f"SELECT {select_clause} FROM {table}", joins, where_sql, order_sql, limit_sql] + query = " ".join(p for p in parts if p).strip() + return self.db.execute(query, tuple(params) if params else None) + + def first(self) -> dict[str, Any] | None: + self._limit = 1 + rows = self.get() + return rows[0] if rows else None + + def count(self, column: str = "*") -> int: + col = "*" if column == "*" else self.dialect.quote_identifier(column) + joins = " ".join(self._joins) + where_sql, params = self._compile_where() + table = self.dialect.quote_identifier(self.table_name) + query = f"SELECT COUNT({col}) AS cnt FROM {table} {joins} {where_sql}".strip() + rows = self.db.execute(query, tuple(params) if params else None) + return int(rows[0]["cnt"]) if rows else 0 + + # ------------------------------------------------------------------ + # INSERT / UPDATE / DELETE + # ------------------------------------------------------------------ + + def insert(self, data: dict[str, Any] | list[dict[str, Any]]) -> list[dict[str, Any]]: + if isinstance(data, list): + return self._insert_many(data) + if isinstance(data, dict): + return self._insert_single(data) + raise error_codes.INVALID_DATA_FORMAT.build( + f"insert() acepta dict o list[dict], recibió {type(data).__name__}." + ) + + def _insert_single(self, data: dict[str, Any]) -> list[dict[str, Any]]: + if not data: + raise error_codes.MISSING_REQUIRED_DATA.build("insert(): dict vacío.") + cols = [validate_identifier(k, kind="column") for k in data] + cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) + placeholders = ", ".join([self.dialect.placeholder] * len(cols)) + table = self.dialect.quote_identifier(self.table_name) + query = f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders})" + return self.db.execute(query, tuple(data.values())) + + def _insert_many(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not rows: + return [] + first_keys = list(rows[0].keys()) + if not first_keys: + raise error_codes.MISSING_REQUIRED_DATA.build("insert(): filas vacías.") + cols = [validate_identifier(k, kind="column") for k in first_keys] + # Forzamos que todas las filas tengan las mismas claves y en el mismo orden. + values: list[tuple[Any, ...]] = [] + for row in rows: + if list(row.keys()) != first_keys: + raise error_codes.INVALID_DATA_FORMAT.build( + "insert(): todas las filas deben tener las mismas claves en el mismo orden." + ) + values.append(tuple(row[k] for k in first_keys)) + cols_sql = ", ".join(self.dialect.quote_identifier(c) for c in cols) + placeholders = ", ".join([self.dialect.placeholder] * len(cols)) + table = self.dialect.quote_identifier(self.table_name) + query = f"INSERT INTO {table} ({cols_sql}) VALUES ({placeholders})" + return self.db.execute(query, values, many=True) + + def update(self, data: dict[str, Any]) -> list[dict[str, Any]]: + if not data: + raise error_codes.MISSING_REQUIRED_DATA.build("update() requiere datos.") + set_parts: list[str] = [] + params: list[Any] = [] + for col, val in data.items(): + validate_identifier(col, kind="column") + set_parts.append(f"{self.dialect.quote_identifier(col)} = {self.dialect.placeholder}") + params.append(val) + where_sql, where_params = self._compile_where() + params.extend(where_params) + table = self.dialect.quote_identifier(self.table_name) + query = f"UPDATE {table} SET {', '.join(set_parts)} {where_sql}".strip() + return self.db.execute(query, tuple(params)) + + def delete(self) -> list[dict[str, Any]]: + where_sql, params = self._compile_where() + if not where_sql: + raise error_codes.MISSING_REQUIRED_DATA.build( + "delete() sin WHERE no está permitido. Usa truncate() si quieres vaciar." + ) + table = self.dialect.quote_identifier(self.table_name) + query = f"DELETE FROM {table} {where_sql}".strip() + return self.db.execute(query, tuple(params)) diff --git a/shiba/core/sql.py b/shiba/core/sql.py new file mode 100644 index 0000000..ca9499b --- /dev/null +++ b/shiba/core/sql.py @@ -0,0 +1,21 @@ +"""Representación neutra de una sentencia SQL ya compilada.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class SQL: + """Una sentencia SQL lista para ejecutar. + + :param text: SQL con placeholders del dialecto correspondiente. + :param params: parámetros (tupla para una sola ejecución, lista de + tuplas si ``many`` es True). + :param many: si True, ``params`` es un iterable de filas para + ``executemany``. + """ + + text: str + params: tuple[Any, ...] | list[tuple[Any, ...]] = field(default_factory=tuple) + many: bool = False diff --git a/shiba/core/table_builder.py b/shiba/core/table_builder.py new file mode 100644 index 0000000..b2ec6d3 --- /dev/null +++ b/shiba/core/table_builder.py @@ -0,0 +1,190 @@ +"""TableBuilder fluido y agnóstico de dialecto. + +Acumula declaraciones de columnas y constraints; ``build()`` compila el +``CREATE TABLE`` final delegando quoting y mapeo de tipos al +:class:`~shiba.dialects.base.Dialect`. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from shiba import error_codes +from shiba.identifiers import validate_identifier + +if TYPE_CHECKING: + from shiba.dialects.base import Dialect + from shiba.dialects.mysql.driver import Database + + +def _escape_enum_choice(choice: str) -> str: + """Escapa una opción de ``ENUM`` duplicando comillas simples.""" + if not isinstance(choice, str): + raise error_codes.INVALID_DATA_FORMAT.build( + f"ENUM choice debe ser str, recibió {type(choice).__name__}." + ) + return "'" + choice.replace("'", "''") + "'" + + +class TableBuilder: + """Acumula columnas y emite ``CREATE TABLE IF NOT EXISTS``.""" + + def __init__(self, db: Database, table_name: str, *, dialect: Dialect) -> None: + self.db = db + self.dialect = dialect + self.table_name = validate_identifier(table_name, kind="table") + self._columns: list[str] = [] + + # ------------------------------------------------------------------ + # Modificadores de la última columna + # ------------------------------------------------------------------ + + def _require_columns(self) -> None: + if not self._columns: + raise error_codes.NO_COLUMNS_DEFINED.build() + + def _append_column(self, declaration: str) -> TableBuilder: + self._columns.append(declaration) + return self + + def _amend_last(self, fragment: str) -> TableBuilder: + self._require_columns() + self._columns[-1] = f"{self._columns[-1]} {fragment}" + return self + + def primary(self) -> TableBuilder: + return self._amend_last("PRIMARY KEY") + + def unique(self) -> TableBuilder: + return self._amend_last("UNIQUE") + + def nullable(self) -> TableBuilder: + return self._amend_last("NULL") + + def not_nullable(self) -> TableBuilder: + return self._amend_last("NOT NULL") + + def default(self, value: str | int | float | bool | None) -> TableBuilder: + if value is None: + return self._amend_last("DEFAULT NULL") + if isinstance(value, bool): + return self._amend_last(f"DEFAULT {1 if value else 0}") + if isinstance(value, (int, float)): + return self._amend_last(f"DEFAULT {value}") + return self._amend_last("DEFAULT " + _escape_enum_choice(value)) + + def foreign( + self, + foreign_name: str | None = None, + table_name: str | None = None, + column_name: str | None = None, + ) -> TableBuilder: + self._require_columns() + if not foreign_name or not table_name or not column_name: + raise error_codes.MISSING_REQUIRED_DATA.build( + "foreign() requiere foreign_name, table_name y column_name." + ) + validate_identifier(foreign_name, kind="constraint") + last = self._columns[-1] + current_col = last.split()[0].strip("`\"[]") + qcol = self.dialect.quote_identifier(current_col) + qtable = self.dialect.quote_identifier(table_name) + qref = self.dialect.quote_identifier(column_name) + qname = self.dialect.quote_identifier(foreign_name) + self._columns[-1] = ( + f"{last}, CONSTRAINT {qname} FOREIGN KEY ({qcol}) REFERENCES {qtable}({qref})" + ) + return self + + # ------------------------------------------------------------------ + # Tipos + # ------------------------------------------------------------------ + + def _col(self, column_name: str, sql_type: str) -> TableBuilder: + validate_identifier(column_name, kind="column") + return self._append_column( + f"{self.dialect.quote_identifier(column_name)} {self.dialect.map_type(sql_type)}" + ) + + def increments(self, column_name: str = "id", primary_key: bool = False) -> TableBuilder: + validate_identifier(column_name, kind="column") + suffix = " PRIMARY KEY" if primary_key else "" + self._append_column( + f"{self.dialect.quote_identifier(column_name)} INT AUTO_INCREMENT{suffix}" + ) + return self + + def integer(self, column_name: str, length: int | None = None) -> TableBuilder: + sql_type = "INT" if length is None else f"INT({int(length)})" + return self._col(column_name, sql_type) + + def big_integer(self, column_name: str) -> TableBuilder: + return self._col(column_name, "BIGINT") + + def small_integer(self, column_name: str) -> TableBuilder: + return self._col(column_name, "SMALLINT") + + def tiny_integer(self, column_name: str) -> TableBuilder: + return self._col(column_name, "TINYINT") + + def string(self, column_name: str, length: int = 255) -> TableBuilder: + return self._col(column_name, f"VARCHAR({int(length)})") + + def text(self, column_name: str) -> TableBuilder: + return self._col(column_name, "TEXT") + + def char(self, column_name: str, length: int = 1) -> TableBuilder: + return self._col(column_name, f"CHAR({int(length)})") + + def date(self, column_name: str) -> TableBuilder: + return self._col(column_name, "DATE") + + def datetime(self, column_name: str) -> TableBuilder: + return self._col(column_name, "DATETIME") + + def time(self, column_name: str) -> TableBuilder: + return self._col(column_name, "TIME") + + def timestamp(self, column_name: str) -> TableBuilder: + return self._col(column_name, "TIMESTAMP") + + def decimal(self, column_name: str, precision: int = 10, scale: int = 2) -> TableBuilder: + return self._col(column_name, f"DECIMAL({int(precision)}, {int(scale)})") + + def floats(self, column_name: str, precision: int = 10, scale: int = 2) -> TableBuilder: + return self._col(column_name, f"FLOAT({int(precision)}, {int(scale)})") + + def boolean(self, column_name: str) -> TableBuilder: + return self._col(column_name, "BOOLEAN") + + def binary(self, column_name: str, length: int | None = None) -> TableBuilder: + sql_type = "BLOB" if length is None else f"BLOB({int(length)})" + return self._col(column_name, sql_type) + + def json(self, column_name: str) -> TableBuilder: + return self._col(column_name, "JSON") + + def enum(self, column_name: str, choices: list[str]) -> TableBuilder: + validate_identifier(column_name, kind="column") + if not choices: + raise error_codes.MISSING_REQUIRED_DATA.build("enum() requiere choices.") + choices_sql = ", ".join(_escape_enum_choice(c) for c in choices) + return self._append_column( + f"{self.dialect.quote_identifier(column_name)} ENUM({choices_sql})" + ) + + # ------------------------------------------------------------------ + # Build + # ------------------------------------------------------------------ + + def build(self) -> TableBuilder: + self._require_columns() + table = self.dialect.quote_identifier(self.table_name) + query = f"CREATE TABLE IF NOT EXISTS {table} ({', '.join(self._columns)})" + self.db.execute(query) + return self + + def to_sql(self) -> str: + """Devuelve la SQL sin ejecutarla. Útil para tests y debug.""" + self._require_columns() + table = self.dialect.quote_identifier(self.table_name) + return f"CREATE TABLE IF NOT EXISTS {table} ({', '.join(self._columns)})" diff --git a/shiba/dialects/__init__.py b/shiba/dialects/__init__.py new file mode 100644 index 0000000..9e9dc80 --- /dev/null +++ b/shiba/dialects/__init__.py @@ -0,0 +1,4 @@ +"""Dialectos SQL soportados por Shiba.""" +from shiba.dialects.base import Dialect + +__all__ = ["Dialect"] diff --git a/shiba/dialects/base.py b/shiba/dialects/base.py new file mode 100644 index 0000000..0cf5679 --- /dev/null +++ b/shiba/dialects/base.py @@ -0,0 +1,43 @@ +"""Contrato común a todos los dialectos SQL. + +Un :class:`Dialect` resuelve las diferencias sintácticas entre motores +(quoting de identificadores, placeholders, tipos, paginación). Los +builders del paquete :mod:`shiba.core` reciben un ``Dialect`` y producen +SQL portable. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class Dialect(ABC): + """Interfaz mínima que todo dialecto debe satisfacer.""" + + name: str + """Identificador corto del dialecto (``mysql``, ``postgres``...).""" + + placeholder: str + """Token de parámetro posicional (``%s``, ``?``, ``$1``...).""" + + @abstractmethod + def quote_identifier(self, name: str) -> str: + """Devuelve el identificador validado y citado.""" + + @abstractmethod + def map_type(self, declared: str) -> str: + """Traduce un tipo canónico de Shiba al SQL del dialecto. + + ``declared`` es la cadena que produce ``TableBuilder`` (p.ej. + ``"VARCHAR(255)"`` o ``"JSON"``). Para MySQL es identidad; para + otros motores se traduce (``BOOLEAN`` → ``TINYINT(1)`` en MySQL + viejo, ``JSON`` → ``JSONB`` en Postgres, etc.). + """ + + def render_limit(self, limit: int | None, offset: int | None) -> str: + """Cláusula ``LIMIT``/``OFFSET`` del dialecto. Default ANSI-ish.""" + parts: list[str] = [] + if limit is not None: + parts.append(f"LIMIT {int(limit)}") + if offset is not None: + parts.append(f"OFFSET {int(offset)}") + return " ".join(parts) diff --git a/shiba/dialects/mysql/__init__.py b/shiba/dialects/mysql/__init__.py new file mode 100644 index 0000000..f51a835 --- /dev/null +++ b/shiba/dialects/mysql/__init__.py @@ -0,0 +1,5 @@ +"""Dialecto MySQL/MariaDB.""" +from shiba.dialects.mysql.dialect import MySQLDialect +from shiba.dialects.mysql.driver import Database + +__all__ = ["Database", "MySQLDialect"] diff --git a/shiba/dialects/mysql/dialect.py b/shiba/dialects/mysql/dialect.py new file mode 100644 index 0000000..2b7685d --- /dev/null +++ b/shiba/dialects/mysql/dialect.py @@ -0,0 +1,19 @@ +"""Implementación :class:`~shiba.dialects.base.Dialect` para MySQL/MariaDB.""" +from __future__ import annotations + +from shiba.dialects.base import Dialect +from shiba.dialects.mysql.quoting import quote_identifier as _qi +from shiba.dialects.mysql.schema import map_type as _map_type + + +class MySQLDialect(Dialect): + """Backticks, placeholder ``%s``, sintaxis ``LIMIT n OFFSET m``.""" + + name = "mysql" + placeholder = "%s" + + def quote_identifier(self, name: str) -> str: + return _qi(name) + + def map_type(self, declared: str) -> str: + return _map_type(declared) diff --git a/shiba/dialects/mysql/driver.py b/shiba/dialects/mysql/driver.py new file mode 100644 index 0000000..8127038 --- /dev/null +++ b/shiba/dialects/mysql/driver.py @@ -0,0 +1,253 @@ +"""Wrapper sobre :mod:`pymysql` con manejo de errores Shiba. + +Diseño +------ +* No comparte cursor entre operaciones — cada ``execute`` abre el suyo. +* Las transacciones se gestionan con ``Database.transaction()`` como + context manager; fuera de una transacción cada operación auto-commit. +* Cualquier ``pymysql`` exception se traduce a un + :class:`~shiba.errors.ShibaError` con su :class:`ErrorCode`. +* Nada de ``print`` — sólo ``logging``. +""" +from __future__ import annotations + +import logging +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +import pymysql +import pymysql.cursors + +from shiba import error_codes +from shiba.dialects.mysql.quoting import quote_identifier +from shiba.error_codes import from_driver_exception +from shiba.errors import QueryError + +if TYPE_CHECKING: + from types import TracebackType + +logger = logging.getLogger("shiba.mysql") + + +class Database: + """Conexión MySQL con API estable de Shiba.""" + + def __init__( + self, + host: str, + port: int, + user: str, + password: str, + *, + database: str | None = None, + charset: str = "utf8mb4", + autoconnect: bool = True, + ) -> None: + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self.charset = charset + self._connection: pymysql.connections.Connection | None = None + self._in_transaction: bool = False + if autoconnect: + self.connect() + + # ------------------------------------------------------------------ + # Ciclo de vida + # ------------------------------------------------------------------ + + def connect(self) -> None: + """Abre la conexión. Idempotente.""" + if self._connection is not None and self._connection.open: + return + try: + self._connection = pymysql.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + database=self.database, + charset=self.charset, + cursorclass=pymysql.cursors.DictCursor, + autocommit=True, + ) + except pymysql.err.OperationalError as exc: + code = from_driver_exception(exc) + raise code.build( + f"No se pudo conectar a {self.host}:{self.port}: {exc}", + details={"host": self.host, "port": self.port}, + ) from exc + except Exception as exc: # pragma: no cover - defensivo + raise error_codes.UNKNOWN_ERROR.build( + f"Error inesperado al conectar: {exc}", + ) from exc + + def close(self) -> None: + """Cierra la conexión si está abierta.""" + if self._connection is not None and self._connection.open: + self._connection.close() + self._connection = None + self._in_transaction = False + + def __enter__(self) -> Database: + self.connect() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + # ------------------------------------------------------------------ + # Acceso interno seguro a la conexión + # ------------------------------------------------------------------ + + @property + def _conn(self) -> pymysql.connections.Connection: + if self._connection is None or not self._connection.open: + raise error_codes.CONNECTION_NOT_OPEN.build( + "La conexión no está abierta. ¿Llamaste a close() ya?" + ) + return self._connection + + # ------------------------------------------------------------------ + # Ejecución + # ------------------------------------------------------------------ + + def execute( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + """Ejecuta ``query`` y devuelve filas (vacío si no hay rowset). + + Hace commit automático salvo que haya una transacción activa. + """ + if not query or not isinstance(query, str): + raise error_codes.EMPTY_QUERY.build( + "Se intentó ejecutar una query vacía o no string.", + query=str(query), + ) + conn = self._conn + try: + with conn.cursor() as cursor: + if params is None: + cursor.execute(query) + elif many: + cursor.executemany(query, params) + else: + cursor.execute(query, params) + try: + rows: list[dict[str, Any]] = list(cursor.fetchall()) + except pymysql.err.Error: + rows = [] + if not self._in_transaction: + conn.commit() + return rows + except pymysql.err.IntegrityError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Violación de integridad: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except pymysql.err.ProgrammingError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error de SQL: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except pymysql.err.OperationalError as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error operacional: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + except pymysql.err.Error as exc: + self._rollback_silent() + code = from_driver_exception(exc) + raise code.build( + f"Error de driver: {exc}", + query=query, + params=params, + cause=exc, + ) from exc + + # Alias retro-compatible con la API v1.x. + execute_query = execute + + def _rollback_silent(self) -> None: + if self._connection is None or not self._connection.open: + return + try: + self._connection.rollback() + except pymysql.err.Error: # pragma: no cover - best effort + logger.warning("rollback failed", exc_info=True) + + # ------------------------------------------------------------------ + # Transacciones + # ------------------------------------------------------------------ + + @contextmanager + def transaction(self) -> Iterator[Database]: + """Bloque transaccional. Commit al salir, rollback en excepción. + + No anidable en esta versión (se reservará para savepoints en + Fase 4). Lanza :data:`error_codes.TRANSACTION_ALREADY_ACTIVE` + si ya hay una activa. + """ + if self._in_transaction: + raise error_codes.TRANSACTION_ALREADY_ACTIVE.build() + conn = self._conn + conn.begin() + self._in_transaction = True + try: + yield self + except BaseException: + self._rollback_silent() + raise + else: + conn.commit() + finally: + self._in_transaction = False + + # ------------------------------------------------------------------ + # DDL/DML conveniencia + # ------------------------------------------------------------------ + + def create_database(self, name: str) -> Database: + """``CREATE DATABASE IF NOT EXISTS`` validando el nombre.""" + try: + self.execute(f"CREATE DATABASE IF NOT EXISTS {quote_identifier(name)}") + except QueryError as exc: + if exc.code is error_codes.INTEGRITY_DUPLICATE_KEY: + logger.info("database %s already exists", name) + else: + raise + self.database = name + return self + + def use_database(self, name: str) -> Database: + """``USE `` validando el nombre.""" + self.execute(f"USE {quote_identifier(name)}") + self.database = name + return self + + # Alias retro-compatible. + selected_database = use_database diff --git a/shiba/dialects/mysql/quoting.py b/shiba/dialects/mysql/quoting.py new file mode 100644 index 0000000..98fb7e4 --- /dev/null +++ b/shiba/dialects/mysql/quoting.py @@ -0,0 +1,13 @@ +"""Quoting de identificadores MySQL (backticks).""" +from __future__ import annotations + +from shiba.identifiers import validate_identifier + + +def quote_identifier(name: str) -> str: + """Valida y cita el identificador con backticks. + + Soporta nombres calificados ``schema.table`` citando cada parte. + """ + validate_identifier(name) + return ".".join(f"`{part}`" for part in name.split(".")) diff --git a/shiba/dialects/mysql/schema.py b/shiba/dialects/mysql/schema.py new file mode 100644 index 0000000..954f710 --- /dev/null +++ b/shiba/dialects/mysql/schema.py @@ -0,0 +1,39 @@ +"""Mapeo de tipos canónicos a SQL MySQL. + +Las declaraciones que produce :mod:`shiba.core.table_builder` ya son +sintaxis MySQL nativa (``VARCHAR(n)``, ``JSON``, ``DECIMAL(p,s)``), por +lo que este mapper es prácticamente identidad. La función existe para +que otros dialectos puedan traducir. +""" +from __future__ import annotations + +# Tipos canónicos aceptados por Shiba. Si el TableBuilder produce un +# tipo fuera de esta lista, el dialecto lo emitirá tal cual pero se +# considera "no soportado" para fines de validación. +SUPPORTED_TYPES: frozenset[str] = frozenset( + { + "INT", + "BIGINT", + "TINYINT", + "SMALLINT", + "VARCHAR", + "TEXT", + "CHAR", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP", + "DECIMAL", + "FLOAT", + "DOUBLE", + "BOOLEAN", + "BLOB", + "ENUM", + "JSON", + } +) + + +def map_type(declared: str) -> str: + """Identidad para MySQL. Punto de extensión para otros dialectos.""" + return declared diff --git a/shiba/error_codes.py b/shiba/error_codes.py new file mode 100644 index 0000000..555619d --- /dev/null +++ b/shiba/error_codes.py @@ -0,0 +1,339 @@ +"""Catálogo de códigos de error propios de Shiba. + +Cada error público lleva un :class:`ErrorCode` estable e independiente +del driver subyacente. Esto permite al cliente: + +* identificar el error sin depender del mensaje (puede estar i18n), +* mapear errores nativos (pymysql, psycopg, etc.) a un dominio común, +* serializar el error sobre HTTP/JSON manteniendo semántica. + +Convención de códigos +--------------------- +* **SHIBA-1xxx** — Conexión / pool +* **SHIBA-2xxx** — Query / SQL +* **SHIBA-3xxx** — Esquema / DDL +* **SHIBA-4xxx** — Integridad +* **SHIBA-5xxx** — Transacciones / concurrencia +* **SHIBA-6xxx** — Datos / validación +* **SHIBA-9xxx** — Genéricos / no clasificados + +Uso +--- +.. code-block:: python + + from shiba import error_codes, ShibaError + + try: + ... + except ShibaError as e: + if e.code is error_codes.INTEGRITY_DUPLICATE_KEY: + ... +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from shiba.errors import ( + ConnectionError, + IntegrityError, + MissingDataError, + QueryError, + SchemaError, + ShibaError, +) + +if TYPE_CHECKING: + from collections.abc import Mapping + + +@dataclass(frozen=True) +class ErrorCode: + """Descriptor inmutable de un error de la librería.""" + + code: str + name: str + default_message: str + exception_class: type[ShibaError] + + def __str__(self) -> str: + return f"{self.code} {self.name}" + + def build(self, message: str | None = None, **kwargs: Any) -> ShibaError: + """Construye (sin lanzar) la excepción asociada con este código.""" + msg = message or self.default_message + return self.exception_class(msg, code=self, **kwargs) + + def raise_(self, message: str | None = None, **kwargs: Any) -> None: + """Lanza la excepción asociada con este código. + + Para ``QueryError`` y descendientes acepta ``query=``, ``params=``, + ``cause=``. Todas las clases aceptan ``details=``. + """ + raise self.build(message, **kwargs) + + +# --- SHIBA-1xxx — Conexión / pool ------------------------------------------- + +CONNECTION_REFUSED = ErrorCode( + "SHIBA-1001", + "CONNECTION_REFUSED", + "No se pudo establecer la conexión con el servidor.", + ConnectionError, +) +CONNECTION_LOST = ErrorCode( + "SHIBA-1002", + "CONNECTION_LOST", + "La conexión con el servidor se perdió.", + ConnectionError, +) +AUTH_FAILED = ErrorCode( + "SHIBA-1003", + "AUTH_FAILED", + "Autenticación rechazada por el servidor.", + ConnectionError, +) +CONNECTION_TIMEOUT = ErrorCode( + "SHIBA-1004", + "CONNECTION_TIMEOUT", + "La conexión excedió el tiempo de espera.", + ConnectionError, +) +POOL_EXHAUSTED = ErrorCode( + "SHIBA-1005", + "POOL_EXHAUSTED", + "El pool de conexiones está agotado.", + ConnectionError, +) +CONNECTION_NOT_OPEN = ErrorCode( + "SHIBA-1006", + "CONNECTION_NOT_OPEN", + "Se intentó operar sobre una conexión cerrada o no inicializada.", + ConnectionError, +) + +# --- SHIBA-2xxx — Query / SQL ----------------------------------------------- + +QUERY_SYNTAX_ERROR = ErrorCode( + "SHIBA-2001", + "QUERY_SYNTAX_ERROR", + "Error de sintaxis SQL.", + QueryError, +) +UNKNOWN_TABLE = ErrorCode( + "SHIBA-2002", + "UNKNOWN_TABLE", + "La tabla referenciada no existe.", + QueryError, +) +UNKNOWN_COLUMN = ErrorCode( + "SHIBA-2003", + "UNKNOWN_COLUMN", + "La columna referenciada no existe.", + QueryError, +) +EMPTY_QUERY = ErrorCode( + "SHIBA-2004", + "EMPTY_QUERY", + "Se intentó ejecutar una query vacía.", + QueryError, +) +INVALID_QUERY_PARAMS = ErrorCode( + "SHIBA-2005", + "INVALID_QUERY_PARAMS", + "Los parámetros proporcionados no son válidos para la query.", + QueryError, +) +QUERY_EXECUTION_FAILED = ErrorCode( + "SHIBA-2099", + "QUERY_EXECUTION_FAILED", + "Falla genérica al ejecutar la query.", + QueryError, +) + +# --- SHIBA-3xxx — Schema / DDL ---------------------------------------------- + +INVALID_IDENTIFIER = ErrorCode( + "SHIBA-3001", + "INVALID_IDENTIFIER", + "El identificador SQL no cumple el formato permitido.", + SchemaError, +) +INVALID_OPERATOR = ErrorCode( + "SHIBA-3002", + "INVALID_OPERATOR", + "Operador SQL no permitido.", + SchemaError, +) +NO_COLUMNS_DEFINED = ErrorCode( + "SHIBA-3003", + "NO_COLUMNS_DEFINED", + "No se han definido columnas en el esquema.", + SchemaError, +) +TABLE_ALREADY_EXISTS = ErrorCode( + "SHIBA-3004", + "TABLE_ALREADY_EXISTS", + "La tabla ya existe.", + SchemaError, +) +DATABASE_ALREADY_EXISTS = ErrorCode( + "SHIBA-3005", + "DATABASE_ALREADY_EXISTS", + "La base de datos ya existe.", + SchemaError, +) +UNSUPPORTED_TYPE = ErrorCode( + "SHIBA-3006", + "UNSUPPORTED_TYPE", + "El tipo de columna no está soportado por este dialecto.", + SchemaError, +) + +# --- SHIBA-4xxx — Integridad ------------------------------------------------ + +INTEGRITY_DUPLICATE_KEY = ErrorCode( + "SHIBA-4001", + "INTEGRITY_DUPLICATE_KEY", + "Violación de clave única o primaria.", + IntegrityError, +) +INTEGRITY_FOREIGN_KEY = ErrorCode( + "SHIBA-4002", + "INTEGRITY_FOREIGN_KEY", + "Violación de restricción de clave foránea.", + IntegrityError, +) +INTEGRITY_NOT_NULL = ErrorCode( + "SHIBA-4003", + "INTEGRITY_NOT_NULL", + "Violación de NOT NULL.", + IntegrityError, +) +INTEGRITY_CHECK = ErrorCode( + "SHIBA-4004", + "INTEGRITY_CHECK", + "Violación de CHECK constraint.", + IntegrityError, +) + +# --- SHIBA-5xxx — Transacciones / concurrencia ------------------------------ + +DEADLOCK_DETECTED = ErrorCode( + "SHIBA-5001", + "DEADLOCK_DETECTED", + "El servidor detectó un deadlock y abortó la transacción.", + QueryError, +) +SERIALIZATION_FAILURE = ErrorCode( + "SHIBA-5002", + "SERIALIZATION_FAILURE", + "La transacción no pudo serializarse y debe reintentarse.", + QueryError, +) +NO_ACTIVE_TRANSACTION = ErrorCode( + "SHIBA-5003", + "NO_ACTIVE_TRANSACTION", + "Operación de transacción sin transacción activa.", + QueryError, +) +TRANSACTION_ALREADY_ACTIVE = ErrorCode( + "SHIBA-5004", + "TRANSACTION_ALREADY_ACTIVE", + "Ya hay una transacción activa en esta conexión.", + QueryError, +) + +# --- SHIBA-6xxx — Datos / validación ---------------------------------------- + +MISSING_REQUIRED_DATA = ErrorCode( + "SHIBA-6001", + "MISSING_REQUIRED_DATA", + "Faltan datos requeridos.", + MissingDataError, +) +INVALID_DATA_FORMAT = ErrorCode( + "SHIBA-6002", + "INVALID_DATA_FORMAT", + "Formato de datos inválido.", + MissingDataError, +) + +# --- SHIBA-9xxx — Genéricos ------------------------------------------------- + +UNKNOWN_ERROR = ErrorCode( + "SHIBA-9001", + "UNKNOWN_ERROR", + "Error no clasificado.", + ShibaError, +) +NOT_IMPLEMENTED = ErrorCode( + "SHIBA-9999", + "NOT_IMPLEMENTED", + "Funcionalidad no implementada en este dialecto.", + ShibaError, +) + + +# Registro indexable por código y por nombre (útil para serialización). +ALL_CODES: tuple[ErrorCode, ...] = tuple( + v for v in globals().values() if isinstance(v, ErrorCode) +) +BY_CODE: Mapping[str, ErrorCode] = {c.code: c for c in ALL_CODES} +BY_NAME: Mapping[str, ErrorCode] = {c.name: c for c in ALL_CODES} + + +# --- Mapeo desde errores nativos del driver --------------------------------- + +# Códigos numéricos MySQL/MariaDB → ErrorCode. Sólo los más comunes; el resto +# cae en QUERY_EXECUTION_FAILED. +# Ref: https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html +_MYSQL_ERRNO_MAP: dict[int, ErrorCode] = { + 1045: AUTH_FAILED, + 1049: UNKNOWN_TABLE, # Unknown database (cercano) + 1051: UNKNOWN_TABLE, + 1054: UNKNOWN_COLUMN, + 1062: INTEGRITY_DUPLICATE_KEY, + 1064: QUERY_SYNTAX_ERROR, + 1146: UNKNOWN_TABLE, + 1213: DEADLOCK_DETECTED, + 1216: INTEGRITY_FOREIGN_KEY, + 1217: INTEGRITY_FOREIGN_KEY, + 1364: INTEGRITY_NOT_NULL, + 1451: INTEGRITY_FOREIGN_KEY, + 1452: INTEGRITY_FOREIGN_KEY, + 2002: CONNECTION_REFUSED, + 2003: CONNECTION_REFUSED, + 2006: CONNECTION_LOST, + 2013: CONNECTION_LOST, + 2059: AUTH_FAILED, + 3819: INTEGRITY_CHECK, +} + + +def from_driver_exception(exc: BaseException) -> ErrorCode: + """Traduce una excepción nativa del driver a nuestro :class:`ErrorCode`. + + Si no se reconoce el error, devuelve :data:`QUERY_EXECUTION_FAILED`. + """ + # pymysql usa `exc.args = (errno, msg)`; psycopg expone `.pgcode`. + errno: int | None = None + args = getattr(exc, "args", None) + if args and isinstance(args, tuple) and args and isinstance(args[0], int): + errno = args[0] + + if errno is not None and errno in _MYSQL_ERRNO_MAP: + return _MYSQL_ERRNO_MAP[errno] + + # Heurística por nombre de clase para drivers que no exponen errno. + name = type(exc).__name__ + if "IntegrityError" in name: + return INTEGRITY_DUPLICATE_KEY + if "OperationalError" in name: + return CONNECTION_LOST + if "ProgrammingError" in name: + return QUERY_SYNTAX_ERROR + if "InterfaceError" in name: + return CONNECTION_NOT_OPEN + + return QUERY_EXECUTION_FAILED diff --git a/shiba/errors.py b/shiba/errors.py new file mode 100644 index 0000000..d281510 --- /dev/null +++ b/shiba/errors.py @@ -0,0 +1,99 @@ +"""Jerarquía de excepciones de Shiba. + +Todas las excepciones públicas heredan de :class:`ShibaError`. Cada +instancia puede llevar un :class:`~shiba.error_codes.ErrorCode` estable +que permite al cliente identificar el error sin depender del mensaje. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from shiba.error_codes import ErrorCode + + +class ShibaError(Exception): + """Excepción base de la librería. + + :param message: descripción legible. + :param code: :class:`ErrorCode` opcional. Si se omite, la excepción + sigue siendo válida pero el cliente no podrá hacer match por + código. + :param details: información estructurada adicional (no PII). + """ + + def __init__( + self, + message: str, + *, + code: ErrorCode | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message) + self.message = message + self.code = code + self.details: dict[str, Any] = details or {} + + def __str__(self) -> str: + if self.code is None: + return self.message + return f"[{self.code.code}] {self.message}" + + def to_dict(self) -> dict[str, Any]: + """Serializa a dict para responder por HTTP/JSON.""" + return { + "code": self.code.code if self.code else None, + "name": self.code.name if self.code else None, + "message": self.message, + "details": self.details, + } + + +class ConnectionError(ShibaError): # noqa: A001 - sombrea builtin a propósito + """Falla al abrir o mantener la conexión con la base de datos.""" + + +class QueryError(ShibaError): + """Falla al ejecutar una sentencia SQL. + + Adjunta la query y los parámetros para facilitar el debug; nunca se + deben loggear directamente si contienen PII. + """ + + def __init__( + self, + message: str, + *, + code: ErrorCode | None = None, + details: dict[str, Any] | None = None, + query: str | None = None, + params: Any = None, + cause: BaseException | None = None, + ) -> None: + super().__init__(message, code=code, details=details) + self.query = query + self.params = params + if cause is not None: + self.__cause__ = cause + + +class IntegrityError(QueryError): + """Violación de constraint (PK duplicada, FK, NOT NULL, CHECK).""" + + +class SchemaError(ShibaError): + """Error en la definición de esquema (DDL inválida, identificador no permitido).""" + + +class MissingDataError(ShibaError): + """Faltan datos requeridos para construir una operación.""" + + +__all__ = [ + "ConnectionError", + "IntegrityError", + "MissingDataError", + "QueryError", + "SchemaError", + "ShibaError", +] diff --git a/shiba/identifiers.py b/shiba/identifiers.py new file mode 100644 index 0000000..964b65d --- /dev/null +++ b/shiba/identifiers.py @@ -0,0 +1,69 @@ +"""Validación de identificadores SQL. + +Cualquier nombre de tabla/columna que toque la librería pasa por +:func:`validate_identifier` *antes* de concatenarse a SQL. El quoting +final (backticks, comillas dobles, corchetes) lo provee el `Dialect`. +""" +from __future__ import annotations + +import re + +from shiba import error_codes + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]{0,63}$") + + +def validate_identifier(name: str, *, kind: str = "identifier") -> str: + """Acepta el identificador o lanza :class:`SchemaError`. + + Permite nombres ``schema.table`` validando cada parte por separado. + """ + if not isinstance(name, str) or not name: + error_codes.INVALID_IDENTIFIER.raise_( + f"{kind} vacío o no string: {name!r}", + details={"kind": kind, "value": repr(name)}, + ) + + for part in name.split("."): + if not _IDENT_RE.match(part): + error_codes.INVALID_IDENTIFIER.raise_( + f"{kind} inválido: {name!r} " + "(sólo [A-Za-z_][A-Za-z0-9_]*, max 64 chars)", + details={"kind": kind, "value": name}, + ) + return name + + +_ALLOWED_OPERATORS = frozenset( + { + "=", + "!=", + "<>", + "<", + "<=", + ">", + ">=", + "LIKE", + "NOT LIKE", + "IN", + "NOT IN", + "IS", + "IS NOT", + } +) + + +def validate_operator(op: str) -> str: + """Acepta el operador o lanza :class:`SchemaError`.""" + if not isinstance(op, str): # defensa para callers no tipados + error_codes.INVALID_OPERATOR.raise_( # type: ignore[unreachable] + f"operador no string: {op!r}", + details={"value": repr(op)}, + ) + normalized = op.strip().upper() if op.strip().isalpha() else op.strip() + if normalized not in _ALLOWED_OPERATORS: + error_codes.INVALID_OPERATOR.raise_( + f"operador no permitido: {op!r}", + details={"value": op, "allowed": sorted(_ALLOWED_OPERATORS)}, + ) + return normalized diff --git a/shibamysql/__init__.py b/shibamysql/__init__.py index 28159e5..caa2911 100644 --- a/shibamysql/__init__.py +++ b/shibamysql/__init__.py @@ -1,20 +1,42 @@ -from shibamysql.database import Database -from shibamysql.table_builder import TableBuilder -from shibamysql.query_builder import QueryBuilder +"""Shim de compatibilidad con la API ``shibamysql`` v1.x. -class ShibaConnection: +Re-exporta los símbolos públicos desde :mod:`shiba`. Programar contra +este módulo emite :class:`DeprecationWarning`; migrar a ``import shiba``. +""" +from __future__ import annotations - def __init__(self, host, port, user, password) -> None: - self.db = Database(host, port, user, password) - - def create_database(self, database): - return self.db.create_database(database) - - def use_database(self, database): - return self.db.selected_database(database) - - def create_table(self, table_name): - return TableBuilder(self.db, table_name) - - def table(self, table_name): - return QueryBuilder(self.db, table_name) \ No newline at end of file +import warnings + +from shiba import ( + ConnectionError, + Database, + IntegrityError, + MissingDataError, + QueryBuilder, + QueryError, + SchemaError, + ShibaConnection, + ShibaError, + TableBuilder, + error_codes, +) + +warnings.warn( + "`shibamysql` está deprecado desde v2.0; importa desde `shiba` en su lugar.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "ConnectionError", + "Database", + "IntegrityError", + "MissingDataError", + "QueryBuilder", + "QueryError", + "SchemaError", + "ShibaConnection", + "ShibaError", + "TableBuilder", + "error_codes", +] diff --git a/shibamysql/database.py b/shibamysql/database.py deleted file mode 100644 index 6678b3b..0000000 --- a/shibamysql/database.py +++ /dev/null @@ -1,115 +0,0 @@ -import pymysql - -class Database: - - """ - The above function is a constructor that initializes the host, port, user, and password - attributes of an object, and establishes a connection to a database. - - :param host: The `host` parameter represents the hostname or IP address of the database server - you want to connect to. It is used to specify the location of the database server - :param port: The `port` parameter is used to specify the port number for the database - connection. It is typically a number that corresponds to a specific service or protocol on the - host machine. For example, the default port for MySQL is 3306, while the default port for - PostgreSQL is 5432 - :param user: The "user" parameter in the above code represents the username used to authenticate - the connection to a database server. It is typically used in combination with the "password" - parameter to establish a secure connection - :param password: The `password` parameter is used to store the password for the user to connect - to the database. It is a string that represents the password - """ - def __init__(self, host, port, user, password) -> None: - self.host = host - self.port = port - self.user = user - self.password = password - self.connection = None - self.cursor = None - self.database = None - self.query = "" - self.connect() - - """ - The function `connect` establishes a connection to a database using the provided host, port, - user, and password. - """ - def connect (self): - try: - self.connection = pymysql.connect( - host=self.host, - port=self.port, - user=self.user, - password=self.password - ) - self.cursor = self.connection.cursor(pymysql.cursors.DictCursor) - except Exception as e: - print (f"Error connecting to the database: {str(e)}") - - """ - The function creates a database with the given name, handling exceptions for existing databases - and other errors. - - :param database: The `database` parameter is a string that represents the name of the database - that you want to create - :return: The method is returning the instance of the class itself (self) after creating the - database. - """ - def create_database(self, database): - - try: - self.cursor.execute(f"CREATE DATABASE {database}") - self.database = database - return self - except pymysql.err.ProgrammingError as e: - if "database exists" in str(e): - print(f"The database '{database}' already exists.") - else: - print(f"Error creating database: {str(e)}") - except Exception as e: - print(f"Error creating database: {str(e)}") - - """ - The function `selected_database` sets the current database to the specified database. - - :param database: The `database` parameter is the name of the database that you want to select - and use - :return: The method is returning the instance of the class itself (self) after setting the - selected database. - """ - def selected_database(self, database): - try: - self.database = self.cursor.execute(f"USE {database}") - return self - except Exception as e: - print (f"Error using database {database}: {str(e)}") - - """ - The function executes a SQL query with optional parameters and returns the result. - - :param query: The query parameter is a string that represents the SQL query you want to execute. - It can be any valid SQL statement, such as SELECT, INSERT, UPDATE, DELETE, etc - :param params: The `params` parameter is used to pass values to the query as parameters. It is - an optional parameter and can be used when the query contains placeholders for values that need - to be dynamically provided - :param many: The "many" parameter is a boolean flag that indicates whether the query should be - executed multiple times with different sets of parameters. If set to True, the "params" argument - should be a list of tuples, where each tuple contains the parameters for a single execution of - the query. If set to False, defaults to False (optional) - :return: the result of the query execution, which is stored in the variable "result". - """ - def execute_query(self, query, params=None, many=False): - try: - if params is None: - self.cursor.execute(query) - else: - if many: - self.cursor.executemany(query, params) - else: - self.cursor.execute(query, params) - self.connection.commit() - result = self.cursor.fetchall() - return result - except Exception as e: - print(f"Error Query: {str(e)}") - self.connection.rollback() - return None \ No newline at end of file diff --git a/shibamysql/missing_data_error.py b/shibamysql/missing_data_error.py index 7badfec..65974a8 100644 --- a/shibamysql/missing_data_error.py +++ b/shibamysql/missing_data_error.py @@ -1,4 +1,4 @@ -class MissingDataError(Exception): - def __init__(self, message): - self.message = message - super().__init__(self.message) +"""Shim de compatibilidad. El contenido vive ahora en :mod:`shiba.errors`.""" +from shiba.errors import MissingDataError + +__all__ = ["MissingDataError"] diff --git a/shibamysql/query_builder.py b/shibamysql/query_builder.py deleted file mode 100644 index 73bd5d5..0000000 --- a/shibamysql/query_builder.py +++ /dev/null @@ -1,178 +0,0 @@ -from shibamysql.missing_data_error import MissingDataError - - -class QueryBuilder: - def __init__(self, db, table_name): - self.db = db - self.table_name = table_name - self.query = None - self.where_conditions = [] - self.selected_columns = [] - self.join_clauses = [] - - def _execute_query(self, params=None, many=False): - result = self.db.execute_query(self.query, params, many) - return result - - def join(self, table_name, column1, operator, column2): - if not table_name or not column1 or not operator or not column2: - raise MissingDataError( - "Falta uno o más parámetros requeridos para la operación JOIN.") - - join_clause = f"JOIN {table_name} ON ({column1} {operator} {column2})" - self.join_clauses.append(join_clause) - return self - - def left_join(self, table_name, column1, operator, column2): - if not table_name or not column1 or not operator or not column2: - raise MissingDataError( - "Falta uno o más parámetros requeridos para la operación LEFT JOIN.") - - join_clause = f"LEFT JOIN {table_name} ON {column1} {operator} {column2}" - self.join_clauses.append(join_clause) - return self - - def right_join(self, table_name, column1, operator, column2): - if not table_name or not column1 or not operator or not column2: - raise MissingDataError( - "Falta uno o más parámetros requeridos para la operación RIGHT JOIN.") - - join_clause = f"RIGHT JOIN {table_name} ON {column1} {operator} {column2}" - self.join_clauses.append(join_clause) - return self - - def cross_join(self, table_name): - if not table_name: - raise MissingDataError( - "Falta uno o más parámetros requeridos para la operación CROSS JOIN.") - - join_clause = f"CROSS JOIN {table_name}" - self.join_clauses.append(join_clause) - return self - - def inner_join(self, table_name, column1, operator, column2): - if not table_name or not column1 or not operator or not column2: - raise MissingDataError( - "Falta uno o más parámetros requeridos para la operación INNER JOIN.") - - join_clause = f"INNER JOIN {table_name} ON {column1} {operator} {column2}" - self.join_clauses.append(join_clause) - return self - - def select(self, *columns): - if not columns: - raise MissingDataError( - "Debe especificar al menos una columna para la operación SELECT.") - - self.selected_columns.extend(columns) - return self - - def where(self, *args): - if isinstance(args, list): - return self._whereArray(args) - - if len(args) == 2: - column, value = args - operator = '=' - elif len(args) == 3: - column, operator, value = args - else: - raise ValueError("Invalid number of arguments for 'where' method") - - self.where_conditions.append(f"{column} {operator} '{value}'") - - return self - - def _whereArray(self, array, boolean='AND'): - for condition in array: - if isinstance(condition, list): - if len(condition) == 2: - column, value = condition - operator = '=' - elif len(condition) == 3: - column, operator, value = condition - else: - raise MissingDataError("Invalid number of arguments for 'where' method") - - if isinstance(value, str): - value = f"{value}" - - a = self.where(column, operator, value) - return self - - def get(self): - select_clause = "*" - if self.selected_columns: - select_clause = ", ".join(self.selected_columns) - - join_clause = " ".join(self.join_clauses) - - where = "" - if self.where_conditions: - where = "WHERE " - where_conditions = " AND ".join(self.where_conditions) - where += where_conditions - - print (where) - - query = f"SELECT {select_clause} FROM {self.table_name} {join_clause} {where}" - result = self.db.execute_query(query) - - return result if result is not None else None - - """ - The function `insert` checks the data format and calls the appropriate method for inserting - either a single item or multiple items into a database. - - :param data: The `data` parameter is the data that you want to insert into the database. It can - be either a single dictionary or a list of dictionaries - :return: The method `insert` returns the result of either `_insert_many` or `_insert_single` - depending on the type of `data` being passed in. - """ - - def insert(self, data): - if isinstance(data, list): - self.many = True - return self._insert_many(data) - elif isinstance(data, dict): - self.many = False - return self._insert_single(data) - else: - print("Invalid data format for insertion.") - return None - - """ - The function inserts a single row of data into a database table. - - :param data: The `data` parameter is a dictionary that contains the column names as keys and the - corresponding values that you want to insert into the table - :return: The method is returning the result of executing the query with the provided values. - """ - - def _insert_single(self, data): - columns = ', '.join(data.keys()) - placeholders = ', '.join(['%s'] * len(data)) - values = tuple(data.values()) - self.query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})" - return self._execute_query(values) - - """ - The function inserts multiple rows of data into a table in a database. - - :param data_list: A list of dictionaries, where each dictionary represents a row of data to be - inserted into the table. Each dictionary should have keys that correspond to the column names in - the table, and the values should be the data to be inserted into those columns - :return: the result of the `_execute_query` method with the `values` parameter set to the list - of tuples created from the `data_list` argument. The `many` parameter is set to `True`, - indicating that multiple rows will be inserted. - """ - - def _insert_many(self, data_list): - if len(data_list) == 0: - return None - - columns = ', '.join(data_list[0].keys()) - placeholders = ', '.join(['%s'] * len(data_list[0])) - values = [tuple(data.values()) for data in data_list] - self.query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})" - return self._execute_query(values, many=True) diff --git a/shibamysql/table_builder.py b/shibamysql/table_builder.py deleted file mode 100644 index ad75a83..0000000 --- a/shibamysql/table_builder.py +++ /dev/null @@ -1,123 +0,0 @@ - -from shibamysql.missing_data_error import MissingDataError - -class TableBuilder: - def __init__(self, database, table_name) -> None: - self.db = database - self.table_name = table_name - self.columns = [] - - def increments(self, column_name='id', primary_key=False): - if not column_name: - raise MissingDataError("El nombre de la columna no puede estar vacío.") - - self.columns.append( - f"{column_name} INT AUTO_INCREMENT {'' if not primary_key else 'PRIMARY KEY'}" - ) - return self - - def primary(self): - if not self.columns: - raise MissingDataError("No se ha agregado ninguna columna.") - - current_column = self.columns[-1] - if 'PRIMARY KEY' not in current_column: - self.columns[-1] += " PRIMARY KEY" - return self - - def unique(self): - if not self.columns: - raise MissingDataError("No se ha agregado ninguna columna.") - - self.columns[-1] += " UNIQUE" - - def foreign(self, foreign_name=None, table_name=None, column_name=None): - if not self.columns: - raise MissingDataError("No se ha agregado ninguna columna.") - if not foreign_name or not table_name or not column_name: - raise MissingDataError("Falta uno o más parámetros requeridos para la restricción FOREIGN KEY.") - - current_column = self.columns[-1] - current_column_name = current_column.split()[0] # Obtener la posición de la columna - self.columns[-1] += f", CONSTRAINT {foreign_name} FOREIGN KEY ({current_column_name}) REFERENCES {table_name}({column_name})" - return self - - """ TYPE """ - - def integer(self, column_name, length=None): - self.columns.append( - f"{column_name} {'INT' if length is None else f'INT({length})'}") - return self - - def string(self, column_name, length=255): - self.columns.append(f"{column_name} VARCHAR({length})") - return self - - def text(self, column_name): - self.columns.append(f"{column_name} TEXT") - return self - - def char(self, column_name, length=1): - self.columns.append(f"{column_name} CHAR({length})") - return self - - def date(self, column_name): - self.columns.append(f"{column_name} DATE") - return self - - def datetime(self, column_name): - self.columns.append(f"{column_name} DATETIME") - return self - - def time(self, column_name): - self.columns.append(f"{column_name} TIME") - return self - - def timestamp(self, column_name): - self.columns.append(f"{column_name} TIMESTAMP") - return self - - def decimal(self, column_name, precision=10, scale=2): - self.columns.append(f"{column_name} DECIMAL({precision}, {scale})") - return self - - def floats(self, column_name, precision=10, scale=2): - self.columns.append(f"{column_name} FLOAT({precision}, {scale})") - return self - - def boolean(self, column_name): - self.columns.append(f"{column_name} BOOLEAN") - return self - - def binary(self, column_name, length=None): - self.columns.append(f"{column_name} {'BLOB' if length is None else f'BLOB({length})'}") - return self - - def enum(self, column_name, choices): - choices_str = ', '.join(f"'{choice}'" for choice in choices) - self.columns.append(f"{column_name} ENUM({choices_str})") - return self - - """ NULL / NOT NULL """ - - def nullable(self): - self.columns[-1] += " NULL" - return self - - def not_nullable(self): - self.columns[-1] += " NOT NULL" - return self - - def build(self): - if not self.columns: - raise MissingDataError("No se ha agregado ninguna columna.") - - query = f"CREATE TABLE IF NOT EXISTS {str(self.table_name)} (" - query += ", ".join(self.columns) - query += ");" - print(query) - result = self.db.execute_query(query) - if result is not None: - print(f"Table Created {self.table_name} Successfully") - else: - return False \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ddb2dab --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +"""Fixtures comunes — mock del Database para tests offline.""" +from __future__ import annotations + +from typing import Any + +import pytest + +from shiba.dialects.mysql import MySQLDialect + + +class FakeDatabase: + """Captura las queries ejecutadas sin tocar MySQL.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, Any, bool]] = [] + self.result: list[dict[str, Any]] = [] + + def execute( + self, + query: str, + params: Any = None, + *, + many: bool = False, + ) -> list[dict[str, Any]]: + self.calls.append((query, params, many)) + return self.result + + @property + def last_call(self) -> tuple[str, Any, bool]: + return self.calls[-1] + + +@pytest.fixture +def fake_db() -> FakeDatabase: + return FakeDatabase() + + +@pytest.fixture +def dialect() -> MySQLDialect: + return MySQLDialect() diff --git a/tests/test_error_codes.py b/tests/test_error_codes.py new file mode 100644 index 0000000..c4ff2a0 --- /dev/null +++ b/tests/test_error_codes.py @@ -0,0 +1,73 @@ +"""Catálogo de error codes y mapper desde drivers nativos.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.errors import ( + ConnectionError, + IntegrityError, + QueryError, + ShibaError, +) + + +def test_code_uniqueness() -> None: + seen: set[str] = set() + for code in error_codes.ALL_CODES: + assert code.code.startswith("SHIBA-") + assert code.code not in seen + seen.add(code.code) + + +def test_build_and_raise_attach_code() -> None: + exc = error_codes.INTEGRITY_DUPLICATE_KEY.build("dup", query="INSERT ...", params=(1,)) + assert isinstance(exc, IntegrityError) + assert exc.code is error_codes.INTEGRITY_DUPLICATE_KEY + assert exc.query == "INSERT ..." + assert exc.params == (1,) + assert str(exc).startswith("[SHIBA-4001]") + + +def test_raise_helper_raises_right_class() -> None: + with pytest.raises(ConnectionError) as ei: + error_codes.CONNECTION_REFUSED.raise_(details={"host": "x"}) + assert ei.value.code is error_codes.CONNECTION_REFUSED + assert ei.value.details == {"host": "x"} + + +def test_shiba_error_to_dict() -> None: + exc = error_codes.UNKNOWN_COLUMN.build("col foo missing", query="SELECT foo FROM t") + d = exc.to_dict() + assert d == { + "code": "SHIBA-2003", + "name": "UNKNOWN_COLUMN", + "message": "col foo missing", + "details": {}, + } + + +def test_lookup_by_code_and_name() -> None: + assert error_codes.BY_CODE["SHIBA-4001"] is error_codes.INTEGRITY_DUPLICATE_KEY + assert error_codes.BY_NAME["INTEGRITY_DUPLICATE_KEY"] is error_codes.INTEGRITY_DUPLICATE_KEY + + +def test_mysql_errno_mapping() -> None: + class FakeMySQLError(Exception): + pass + + exc = FakeMySQLError(1062, "Duplicate entry") + assert error_codes.from_driver_exception(exc) is error_codes.INTEGRITY_DUPLICATE_KEY + + +def test_unknown_driver_error_falls_back() -> None: + class WeirdError(Exception): + pass + + assert error_codes.from_driver_exception(WeirdError("?")) is error_codes.QUERY_EXECUTION_FAILED + + +def test_shibaerror_is_base() -> None: + exc = error_codes.QUERY_SYNTAX_ERROR.build() + assert isinstance(exc, ShibaError) + assert isinstance(exc, QueryError) diff --git a/tests/test_identifiers.py b/tests/test_identifiers.py new file mode 100644 index 0000000..031234f --- /dev/null +++ b/tests/test_identifiers.py @@ -0,0 +1,54 @@ +"""Validación de identificadores y operadores.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.errors import SchemaError +from shiba.identifiers import validate_identifier, validate_operator + + +@pytest.mark.parametrize( + "name", + ["users", "_private", "users.id", "schema.table", "a", "x123"], +) +def test_valid_identifiers(name: str) -> None: + assert validate_identifier(name) == name + + +@pytest.mark.parametrize( + "name", + [ + "users; DROP TABLE x", + "users--", + "1users", + "", + "user name", + "users'", + "`users`", + "a" * 65, + ], +) +def test_invalid_identifiers_raise(name: str) -> None: + with pytest.raises(SchemaError) as ei: + validate_identifier(name) + assert ei.value.code is error_codes.INVALID_IDENTIFIER + + +def test_invalid_identifier_carries_details() -> None: + with pytest.raises(SchemaError) as ei: + validate_identifier("bad;name", kind="column") + assert ei.value.code is error_codes.INVALID_IDENTIFIER + assert ei.value.details["kind"] == "column" + + +@pytest.mark.parametrize("op", ["=", "!=", "<>", "<", ">=", "LIKE", "IN", "IS NOT"]) +def test_valid_operators(op: str) -> None: + assert validate_operator(op) == op.upper().strip() if op.isalpha() else op + + +@pytest.mark.parametrize("op", ["==", "; DROP", "OR 1=1", ""]) +def test_invalid_operators(op: str) -> None: + with pytest.raises(SchemaError) as ei: + validate_operator(op) + assert ei.value.code is error_codes.INVALID_OPERATOR diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py new file mode 100644 index 0000000..e1a3539 --- /dev/null +++ b/tests/test_query_builder.py @@ -0,0 +1,139 @@ +"""QueryBuilder genera SQL parametrizado correcto.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.query_builder import QueryBuilder +from shiba.errors import MissingDataError, SchemaError + + +def test_select_quotes_columns(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).select("id", "name").get() + sql, params, many = fake_db.last_call + assert sql.startswith("SELECT `id`, `name` FROM `users`") + assert params is None + assert many is False + + +def test_where_parametrizes_value(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("name", "John").get() + sql, params, _ = fake_db.last_call + assert "WHERE `name` = %s" in sql + assert params == ("John",) + + +def test_where_rejects_injection_in_value_via_params(fake_db, dialect) -> None: + # El valor con inyección llega como parámetro, no concatenado. + payload = "x' OR '1'='1" + QueryBuilder(fake_db, "users", dialect=dialect).where("name", payload).get() + sql, params, _ = fake_db.last_call + assert payload not in sql + assert params == (payload,) + + +def test_where_rejects_injection_in_column(fake_db, dialect) -> None: + with pytest.raises(SchemaError) as ei: + QueryBuilder(fake_db, "users", dialect=dialect).where("name; DROP", "x").get() + assert ei.value.code is error_codes.INVALID_IDENTIFIER + + +def test_where_rejects_bad_operator(fake_db, dialect) -> None: + with pytest.raises(SchemaError) as ei: + QueryBuilder(fake_db, "users", dialect=dialect).where("name", "OR 1=1", "x").get() + assert ei.value.code is error_codes.INVALID_OPERATOR + + +def test_where_in_expands_placeholders(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("id", "IN", [1, 2, 3]).get() + sql, params, _ = fake_db.last_call + assert "`id` IN (%s, %s, %s)" in sql + assert params == (1, 2, 3) + + +def test_where_array_does_not_crash_on_three_element_condition(fake_db, dialect) -> None: + # Regresión: en v1.x _whereArray accedía a `value` sin definir cuando + # la condición tenía 3 elementos. + QueryBuilder(fake_db, "users", dialect=dialect).where( + [["age", ">", 18], ["name", "Alice"]] + ).get() + sql, params, _ = fake_db.last_call + assert "`age` > %s" in sql + assert "`name` = %s" in sql + assert params == (18, "Alice") + + +def test_join_quotes_identifiers(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).left_join( + "orders", "users.id", "=", "orders.user_id" + ).get() + sql, _, _ = fake_db.last_call + assert "LEFT JOIN `orders` ON `users`.`id` = `orders`.`user_id`" in sql + + +def test_order_by_and_limit(fake_db, dialect) -> None: + ( + QueryBuilder(fake_db, "users", dialect=dialect) + .order_by("name", "DESC") + .limit(10) + .offset(5) + .get() + ) + sql, _, _ = fake_db.last_call + assert "ORDER BY `name` DESC" in sql + assert "LIMIT 10" in sql + assert "OFFSET 5" in sql + + +def test_insert_single(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).insert({"name": "John", "age": 30}) + sql, params, many = fake_db.last_call + assert sql == "INSERT INTO `users` (`name`, `age`) VALUES (%s, %s)" + assert params == ("John", 30) + assert many is False + + +def test_insert_many(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).insert( + [{"name": "A", "age": 1}, {"name": "B", "age": 2}] + ) + sql, params, many = fake_db.last_call + assert sql == "INSERT INTO `users` (`name`, `age`) VALUES (%s, %s)" + assert params == [("A", 1), ("B", 2)] + assert many is True + + +def test_insert_many_rejects_mismatched_keys(fake_db, dialect) -> None: + with pytest.raises(MissingDataError) as ei: + QueryBuilder(fake_db, "users", dialect=dialect).insert( + [{"name": "A", "age": 1}, {"age": 2, "name": "B"}] + ) + assert ei.value.code is error_codes.INVALID_DATA_FORMAT + + +def test_update(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("id", 1).update({"name": "X"}) + sql, params, _ = fake_db.last_call + assert sql == "UPDATE `users` SET `name` = %s WHERE `id` = %s" + assert params == ("X", 1) + + +def test_delete_requires_where(fake_db, dialect) -> None: + with pytest.raises(MissingDataError) as ei: + QueryBuilder(fake_db, "users", dialect=dialect).delete() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +def test_delete_with_where(fake_db, dialect) -> None: + QueryBuilder(fake_db, "users", dialect=dialect).where("id", 1).delete() + sql, params, _ = fake_db.last_call + assert sql == "DELETE FROM `users` WHERE `id` = %s" + assert params == (1,) + + +def test_count(fake_db, dialect) -> None: + fake_db.result = [{"cnt": 42}] + n = QueryBuilder(fake_db, "users", dialect=dialect).where("active", True).count() + assert n == 42 + sql, _, _ = fake_db.last_call + assert sql.startswith("SELECT COUNT(*) AS cnt FROM `users`") diff --git a/tests/test_table_builder.py b/tests/test_table_builder.py new file mode 100644 index 0000000..e402f2b --- /dev/null +++ b/tests/test_table_builder.py @@ -0,0 +1,82 @@ +"""TableBuilder produce DDL coherente y mantiene su fluent API.""" +from __future__ import annotations + +import pytest + +from shiba import error_codes +from shiba.core.table_builder import TableBuilder +from shiba.errors import MissingDataError, SchemaError + + +def test_basic_create(fake_db, dialect) -> None: + sql = ( + TableBuilder(fake_db, "users", dialect=dialect) + .increments("id", primary_key=True) + .string("name", 20) + .integer("age") + .to_sql() + ) + assert "CREATE TABLE IF NOT EXISTS `users`" in sql + assert "`id` INT AUTO_INCREMENT PRIMARY KEY" in sql + assert "`name` VARCHAR(20)" in sql + assert "`age` INT" in sql + + +def test_unique_returns_self_fluent_chain(fake_db, dialect) -> None: + # Regresión: v1.x unique() no retornaba self y rompía el chain. + tb = TableBuilder(fake_db, "users", dialect=dialect).string("email").unique().not_nullable() + assert isinstance(tb, TableBuilder) + sql = tb.to_sql() + assert "`email` VARCHAR(255) UNIQUE NOT NULL" in sql + + +def test_build_returns_self(fake_db, dialect) -> None: + tb = TableBuilder(fake_db, "users", dialect=dialect).integer("id") + assert tb.build() is tb + + +def test_enum_escapes_choices(fake_db, dialect) -> None: + sql = ( + TableBuilder(fake_db, "u", dialect=dialect) + .enum("status", ["active", "inactive", "O'Brien"]) + .to_sql() + ) + assert "ENUM('active', 'inactive', 'O''Brien')" in sql + + +def test_invalid_column_name_raises(fake_db, dialect) -> None: + with pytest.raises(SchemaError) as ei: + TableBuilder(fake_db, "users", dialect=dialect).integer("id; DROP TABLE users") + assert ei.value.code is error_codes.INVALID_IDENTIFIER + + +def test_build_without_columns_raises(fake_db, dialect) -> None: + with pytest.raises(SchemaError) as ei: + TableBuilder(fake_db, "users", dialect=dialect).build() + assert ei.value.code is error_codes.NO_COLUMNS_DEFINED + + +def test_foreign_key_requires_all_params(fake_db, dialect) -> None: + with pytest.raises(MissingDataError) as ei: + TableBuilder(fake_db, "orders", dialect=dialect).integer("user_id").foreign() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +def test_foreign_key_renders(fake_db, dialect) -> None: + sql = ( + TableBuilder(fake_db, "orders", dialect=dialect) + .integer("user_id") + .foreign("fk_user", "users", "id") + .to_sql() + ) + assert "CONSTRAINT `fk_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)" in sql + + +def test_default_value_escaping(fake_db, dialect) -> None: + sql = ( + TableBuilder(fake_db, "t", dialect=dialect) + .string("name") + .default("O'Brien") + .to_sql() + ) + assert "DEFAULT 'O''Brien'" in sql