diff --git a/shiba/__init__.py b/shiba/__init__.py index a4ce4a8..873fe47 100644 --- a/shiba/__init__.py +++ b/shiba/__init__.py @@ -35,6 +35,7 @@ SchemaError, ShibaError, ) +from shiba.orm import Model, fields, set_default_connection if TYPE_CHECKING: from types import TracebackType @@ -109,6 +110,7 @@ def raw( "Database", "IntegrityError", "MissingDataError", + "Model", "MySQLDialect", "QueryBuilder", "QueryError", @@ -117,4 +119,6 @@ def raw( "ShibaError", "TableBuilder", "error_codes", + "fields", + "set_default_connection", ] diff --git a/shiba/error_codes.py b/shiba/error_codes.py index 555619d..ef5198d 100644 --- a/shiba/error_codes.py +++ b/shiba/error_codes.py @@ -32,7 +32,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from shiba.errors import ( ConnectionError, @@ -64,7 +64,7 @@ def build(self, message: str | None = None, **kwargs: Any) -> ShibaError: msg = message or self.default_message return self.exception_class(msg, code=self, **kwargs) - def raise_(self, message: str | None = None, **kwargs: Any) -> None: + def raise_(self, message: str | None = None, **kwargs: Any) -> NoReturn: """Lanza la excepción asociada con este código. Para ``QueryError`` y descendientes acepta ``query=``, ``params=``, diff --git a/shiba/identifiers.py b/shiba/identifiers.py index 964b65d..3f2083d 100644 --- a/shiba/identifiers.py +++ b/shiba/identifiers.py @@ -55,8 +55,8 @@ def validate_identifier(name: str, *, kind: str = "identifier") -> str: def validate_operator(op: str) -> str: """Acepta el operador o lanza :class:`SchemaError`.""" - if not isinstance(op, str): # defensa para callers no tipados - error_codes.INVALID_OPERATOR.raise_( # type: ignore[unreachable] + if not isinstance(op, str): # defensa para callers no tipados # type: ignore[unreachable] + error_codes.INVALID_OPERATOR.raise_( f"operador no string: {op!r}", details={"value": repr(op)}, ) diff --git a/shiba/orm/__init__.py b/shiba/orm/__init__.py new file mode 100644 index 0000000..ecaf73a --- /dev/null +++ b/shiba/orm/__init__.py @@ -0,0 +1,40 @@ +"""ORM tipado para Shiba. + +Uso mínimo: + +.. code-block:: python + + import shiba + from shiba.orm import Model, fields + + class User(Model): + __table__ = "users" + + id: int = fields.PrimaryKey() + name: str + email: str = fields.String(unique=True) + age: int | None = None + + shiba.set_default_connection(cx) + + User.create_table() + user = User(name="John", email="j@x.com", age=30) + user.save() + User.find(1) + User.where("age", ">", 18).get() +""" +from shiba.orm import fields +from shiba.orm.model import ( + Model, + ModelQuery, + get_default_connection, + set_default_connection, +) + +__all__ = [ + "Model", + "ModelQuery", + "fields", + "get_default_connection", + "set_default_connection", +] diff --git a/shiba/orm/fields.py b/shiba/orm/fields.py new file mode 100644 index 0000000..ebf7ef9 --- /dev/null +++ b/shiba/orm/fields.py @@ -0,0 +1,305 @@ +"""Field descriptors para modelos Shiba. + +Los modelos del ORM combinan anotaciones Python (que dan el tipo) con +instancias de :class:`Field` (que dan metadata extra: unique, default, +auto-increment, etc.). Cuando un atributo sólo tiene anotación, el +:class:`Field` se infiere automáticamente con defaults razonables. +""" +from __future__ import annotations + +import json +import types +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import date, datetime +from decimal import Decimal +from typing import Any, Union, get_args, get_origin + +from shiba import error_codes + +_UNSET: Any = object() +"""Sentinel para distinguir 'no se pasó valor' de 'None explícito'.""" + + +@dataclass +class Field: + """Descriptor de columna SQL. + + Casi nunca se instancia directamente — se usan las subclases + semánticas (:class:`String`, :class:`Integer`, etc.) y la inferencia + desde anotaciones. + """ + + sql_type: str = "VARCHAR(255)" + nullable: bool = False + primary_key: bool = False + unique: bool = False + auto_increment: bool = False + default: Any = _UNSET + default_factory: Callable[[], Any] | None = None + column_name: str | None = None + foreign_key: tuple[str, str] | None = None + json: bool = False + enum_choices: tuple[str, ...] | None = None + indexed: bool = False + + # --- Conversión Python <-> DB -------------------------------------- + + def to_python(self, raw: Any) -> Any: + """Decodifica el valor que llega desde la fila SQL.""" + if raw is None: + return None + if self.json and isinstance(raw, (str, bytes)): + return json.loads(raw) + return raw + + def to_db(self, value: Any) -> Any: + """Codifica el valor antes de enviarlo a la base de datos.""" + if value is None: + return None + if self.json and not isinstance(value, (str, bytes)): + return json.dumps(value, default=str) + return value + + # --- Defaults ------------------------------------------------------- + + def has_default(self) -> bool: + return self.default is not _UNSET or self.default_factory is not None + + def get_default(self) -> Any: + if self.default_factory is not None: + return self.default_factory() + if self.default is _UNSET: + return None + return self.default + + # --- DDL ------------------------------------------------------------ + + def apply_to_table_builder(self, tb: Any, column_name: str) -> None: + """Replica este campo en el ``TableBuilder``.""" + col = self.column_name or column_name + # Tipo + if self.primary_key and self.auto_increment and self.sql_type == "INT": + tb.increments(col, primary_key=True) + elif self.enum_choices is not None: + tb.enum(col, list(self.enum_choices)) + else: + # Reusar map de tipos via raw type string. + type_lower = self.sql_type.lower() + if type_lower.startswith("varchar"): + length = int(self.sql_type[8:-1]) if "(" in self.sql_type else 255 + tb.string(col, length) + elif type_lower == "text": + tb.text(col) + elif type_lower.startswith("int"): + tb.integer(col) + elif type_lower == "bigint": + tb.big_integer(col) + elif type_lower == "smallint": + tb.small_integer(col) + elif type_lower == "tinyint": + tb.tiny_integer(col) + elif type_lower == "boolean": + tb.boolean(col) + elif type_lower == "json": + tb.json(col) + elif type_lower == "datetime": + tb.datetime(col) + elif type_lower == "date": + tb.date(col) + elif type_lower == "time": + tb.time(col) + elif type_lower == "timestamp": + tb.timestamp(col) + elif type_lower.startswith("decimal"): + tb.decimal(col) + elif type_lower.startswith("float") or type_lower == "double": + tb.floats(col) + elif type_lower == "blob": + tb.binary(col) + else: + error_codes.UNSUPPORTED_TYPE.raise_( + f"tipo {self.sql_type!r} no soportado por apply_to_table_builder." + ) + # Constraints adicionales — añadidas al último column declared. + if self.primary_key and not self.auto_increment: + tb.primary() + if self.unique: + tb.unique() + if self.nullable: + tb.nullable() + elif not self.primary_key: + tb.not_nullable() + if self.default is not _UNSET and self.default_factory is None: + tb.default(self.default) + if self.foreign_key is not None: + ftable, fcolumn = self.foreign_key + tb.foreign(f"fk_{col}_{ftable}", ftable, fcolumn) + + +# --------------------------------------------------------------------------- +# Subclases semánticas +# --------------------------------------------------------------------------- + + +@dataclass +class PrimaryKey(Field): + sql_type: str = "INT" + primary_key: bool = True + auto_increment: bool = True + + +@dataclass +class String(Field): + max_length: int = 255 + + def __post_init__(self) -> None: + self.sql_type = f"VARCHAR({self.max_length})" + + +@dataclass +class Text(Field): + sql_type: str = "TEXT" + + +@dataclass +class Integer(Field): + sql_type: str = "INT" + + +@dataclass +class BigInteger(Field): + sql_type: str = "BIGINT" + + +@dataclass +class Boolean(Field): + sql_type: str = "BOOLEAN" + + def to_python(self, raw: Any) -> Any: + if raw is None: + return None + return bool(raw) + + +@dataclass +class FloatField(Field): + sql_type: str = "FLOAT" + + +@dataclass +class DecimalField(Field): + sql_type: str = "DECIMAL" + + def to_python(self, raw: Any) -> Any: + if raw is None: + return None + return Decimal(str(raw)) + + +@dataclass +class DateTime(Field): + sql_type: str = "DATETIME" + default_now: bool = False + + def __post_init__(self) -> None: + if self.default_now and self.default_factory is None: + self.default_factory = datetime.now + + def to_python(self, raw: Any) -> Any: + if raw is None or isinstance(raw, datetime): + return raw + return datetime.fromisoformat(str(raw)) + + +@dataclass +class DateField(Field): + sql_type: str = "DATE" + + def to_python(self, raw: Any) -> Any: + if raw is None or isinstance(raw, date): + return raw + return date.fromisoformat(str(raw)) + + +@dataclass +class Json(Field): + sql_type: str = "JSON" + json: bool = True + + +@dataclass +class Enum(Field): + choices: tuple[str, ...] = field(default_factory=tuple) + + def __post_init__(self) -> None: + if not self.choices: + error_codes.MISSING_REQUIRED_DATA.raise_( + "Enum requiere choices." + ) + self.sql_type = "ENUM" + self.enum_choices = self.choices + + +@dataclass +class ForeignKey(Field): + """``ForeignKey(to='users', column='id')``.""" + + to: str = "" + column: str = "id" + sql_type: str = "INT" + + def __post_init__(self) -> None: + if not self.to: + error_codes.MISSING_REQUIRED_DATA.raise_( + "ForeignKey requiere `to` (tabla destino)." + ) + self.foreign_key = (self.to, self.column) + + +# --------------------------------------------------------------------------- +# Inferencia desde anotaciones +# --------------------------------------------------------------------------- + + +def _is_optional(hint: Any) -> tuple[bool, Any]: + """Devuelve ``(nullable, inner_type)`` para ``X | None``.""" + origin = get_origin(hint) + if origin in (Union, types.UnionType): + args = [a for a in get_args(hint) if a is not type(None)] + if len(get_args(hint)) > len(args): + return True, args[0] if len(args) == 1 else hint + return False, hint + + +def infer_field(hint: Any, default: Any = _UNSET) -> Field: + """Construye un :class:`Field` a partir de la anotación del atributo.""" + nullable, inner = _is_optional(hint) + has_default = default is not _UNSET + f: Field + + if inner is str: + f = String() + elif inner is int: + f = Integer() + elif inner is bool: + f = Boolean() + elif inner is float: + f = FloatField() + elif inner is Decimal: + f = DecimalField() + elif inner is datetime: + f = DateTime() + elif inner is date: + f = DateField() + elif inner is bytes: + f = Field(sql_type="BLOB") + elif inner is dict or get_origin(inner) is dict or inner is list or get_origin(inner) is list: + f = Json() + else: + f = Field() + + f.nullable = nullable or (has_default and default is None) + if has_default: + f.default = default + return f diff --git a/shiba/orm/model.py b/shiba/orm/model.py new file mode 100644 index 0000000..ced267c --- /dev/null +++ b/shiba/orm/model.py @@ -0,0 +1,314 @@ +"""Modelos POO con metaclass que lee anotaciones.""" +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar + +from shiba import error_codes +from shiba.orm.fields import _UNSET, Field, infer_field + +if TYPE_CHECKING: + from shiba import ShibaConnection + from shiba.core.query_builder import QueryBuilder + + +T = TypeVar("T", bound="Model") + + +_default_connection: ShibaConnection | None = None + + +def set_default_connection(connection: ShibaConnection) -> None: + """Registra la conexión global usada por modelos sin ``__db__``.""" + global _default_connection + _default_connection = connection + + +def get_default_connection() -> ShibaConnection | None: + return _default_connection + + +class ModelMeta(type): + """Metaclass que extrae ``_fields`` desde las anotaciones de la clase.""" + + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + ns: dict[str, Any], + ) -> ModelMeta: + cls = super().__new__(mcs, name, bases, ns) + + # No procesamos la propia clase base ``Model``. + if ns.get("__shiba_model_root__", False): + return cls + + # Recolectamos anotaciones por clase del MRO, evaluando strings + # con eval_str. Esto evita fallar si una clase ancestra tiene + # forward refs no resolvibles (p. ej. ``ShibaConnection``). + annotations: dict[str, Any] = {} + for klass in reversed(cls.__mro__): + if klass is object: + continue + try: + hints = inspect.get_annotations(klass, eval_str=True) + except (NameError, AttributeError): + hints = inspect.get_annotations(klass, eval_str=False) + annotations.update(hints) + + fields: dict[str, Field] = {} + for attr, hint in annotations.items(): + if attr.startswith("_") or attr in {"ClassVar"}: + continue + value = ns.get(attr, _UNSET) + fld = value if isinstance(value, Field) else infer_field(hint, default=value) + fields[attr] = fld + + # Quitamos el Field del namespace para que la lookup pase por + # la instancia y no devuelva el descriptor. + if attr in cls.__dict__ and isinstance(cls.__dict__[attr], Field): + delattr(cls, attr) + + cls._fields = fields # type: ignore[attr-defined] + cls._table = ns.get("__table__", name.lower()) # type: ignore[attr-defined] + return cls + + +class Model(metaclass=ModelMeta): + """Base de cualquier modelo. Se levanta como objeto Python plano.""" + + __shiba_model_root__: ClassVar[bool] = True + _fields: ClassVar[dict[str, Field]] = {} + _table: ClassVar[str] = "" + __db__: ClassVar[Any] = None # ShibaConnection | None — tipado en docstring + + # ------------------------------------------------------------------ + # Construcción + # ------------------------------------------------------------------ + + def __init__(self, **kwargs: Any) -> None: + unknown = set(kwargs) - set(self._fields) + if unknown: + error_codes.INVALID_DATA_FORMAT.raise_( + f"{type(self).__name__}: claves desconocidas {sorted(unknown)}", + details={"unknown": sorted(unknown)}, + ) + for attr, fld in self._fields.items(): + if attr in kwargs: + setattr(self, attr, kwargs[attr]) + elif fld.has_default(): + setattr(self, attr, fld.get_default()) + else: + setattr(self, attr, None) + + def __repr__(self) -> str: + pairs = ", ".join(f"{k}={getattr(self, k, None)!r}" for k in self._fields) + return f"{type(self).__name__}({pairs})" + + def __eq__(self, other: object) -> bool: + if type(self) is not type(other): + return NotImplemented + return self.to_dict() == other.to_dict() # type: ignore[attr-defined,no-any-return] + + def __hash__(self) -> int: # pragma: no cover - identidad por pk + pk_attr = self._pk_attr() + return hash((type(self).__name__, getattr(self, pk_attr, None))) + + # ------------------------------------------------------------------ + # Introspección + # ------------------------------------------------------------------ + + @classmethod + def _pk_attr(cls) -> str: + for attr, fld in cls._fields.items(): + if fld.primary_key: + return attr + error_codes.MISSING_REQUIRED_DATA.raise_( + f"{cls.__name__}: ningún campo marcado como primary_key." + ) + + def to_dict(self) -> dict[str, Any]: + return {attr: getattr(self, attr, None) for attr in self._fields} + + def to_db_dict(self, *, exclude_pk_if_none: bool = True) -> dict[str, Any]: + """Devuelve el dict listo para INSERT/UPDATE.""" + pk_attr = self._pk_attr() + out: dict[str, Any] = {} + for attr, fld in self._fields.items(): + value = getattr(self, attr, None) + if ( + attr == pk_attr + and exclude_pk_if_none + and (value is None or value == _UNSET) + ): + continue + col = fld.column_name or attr + out[col] = fld.to_db(value) + return out + + @classmethod + def from_row(cls: type[T], row: dict[str, Any]) -> T: + """Hidrata una instancia desde una fila ``dict``.""" + instance = cls.__new__(cls) + for attr, fld in cls._fields.items(): + col = fld.column_name or attr + raw = row.get(col) + setattr(instance, attr, fld.to_python(raw)) + return instance + + # ------------------------------------------------------------------ + # Conexión + # ------------------------------------------------------------------ + + @classmethod + def _connection(cls) -> ShibaConnection: + conn: ShibaConnection | None = cls.__db__ or _default_connection + if conn is None: + error_codes.CONNECTION_NOT_OPEN.raise_( + f"{cls.__name__} no tiene conexión. Llama a " + "shiba.set_default_connection(cx) o asigna `__db__`." + ) + return conn + + # ------------------------------------------------------------------ + # Query API a nivel de clase + # ------------------------------------------------------------------ + + @classmethod + def query(cls: type[T]) -> ModelQuery[T]: + return ModelQuery(cls) + + @classmethod + def all(cls: type[T]) -> list[T]: + return cls.query().get() + + @classmethod + def find(cls: type[T], pk_value: Any) -> T | None: + row = cls._connection().table(cls._table).find(pk_value, pk=cls._pk_attr()) + return cls.from_row(row) if row else None + + @classmethod + def where(cls: type[T], *args: Any) -> ModelQuery[T]: + return cls.query().where(*args) + + @classmethod + def first(cls: type[T]) -> T | None: + return cls.query().first() + + @classmethod + def count(cls) -> int: + return cls._connection().table(cls._table).count() + + # ------------------------------------------------------------------ + # Schema + # ------------------------------------------------------------------ + + @classmethod + def create_table(cls) -> None: + tb = cls._connection().create_table(cls._table) + for attr, fld in cls._fields.items(): + fld.apply_to_table_builder(tb, attr) + tb.build() + + @classmethod + def drop_table(cls) -> None: + cx = cls._connection() + cx.raw(f"DROP TABLE IF EXISTS {cx.dialect.quote_identifier(cls._table)}") + + @classmethod + def truncate_table(cls) -> None: + cls._connection().table(cls._table).truncate() + + # ------------------------------------------------------------------ + # Persistencia + # ------------------------------------------------------------------ + + def save(self: T) -> T: + pk_attr = self._pk_attr() + pk_val = getattr(self, pk_attr, None) + cx = self._connection() + data = self.to_db_dict(exclude_pk_if_none=True) + if pk_val is None: + cx.table(self._table).insert(data) + new_id = cx.raw("SELECT LAST_INSERT_ID() AS v") + if new_id and new_id[0].get("v"): + setattr(self, pk_attr, new_id[0]["v"]) + else: + cx.table(self._table).where(pk_attr, pk_val).update(data) + return self + + def delete(self) -> None: + pk_attr = self._pk_attr() + pk_val = getattr(self, pk_attr, None) + if pk_val is None: + error_codes.MISSING_REQUIRED_DATA.raise_( + f"delete(): {type(self).__name__} sin PK." + ) + self._connection().table(self._table).where(pk_attr, pk_val).delete() + + +# --------------------------------------------------------------------------- +# Manager hidratante +# --------------------------------------------------------------------------- + + +class ModelQuery(Generic[T]): + """Wrapper de :class:`QueryBuilder` que devuelve modelos en vez de dicts.""" + + def __init__(self, model_cls: type[T]) -> None: + self.model_cls = model_cls + cx = model_cls._connection() + self._qb: QueryBuilder = cx.table(model_cls._table) + + # Delegación al builder con retorno fluido. + def where(self, *args: Any) -> ModelQuery[T]: + self._qb.where(*args) + return self + + def or_where(self, *args: Any) -> ModelQuery[T]: + self._qb.or_where(*args) + return self + + def where_in(self, column: str, values: list[Any]) -> ModelQuery[T]: + self._qb.where_in(column, values) + return self + + def where_null(self, column: str) -> ModelQuery[T]: + self._qb.where_null(column) + return self + + def where_not_null(self, column: str) -> ModelQuery[T]: + self._qb.where_not_null(column) + return self + + def order_by(self, column: str, direction: str = "ASC") -> ModelQuery[T]: + self._qb.order_by(column, direction) + return self + + def limit(self, n: int) -> ModelQuery[T]: + self._qb.limit(n) + return self + + def offset(self, n: int) -> ModelQuery[T]: + self._qb.offset(n) + return self + + # Ejecución hidratada. + def get(self) -> list[T]: + rows = self._qb.get() + return [self.model_cls.from_row(r) for r in rows] + + def first(self) -> T | None: + row = self._qb.first() + return self.model_cls.from_row(row) if row else None + + def count(self) -> int: + return self._qb.count() + + def exists(self) -> bool: + return self._qb.exists() + + def paginate(self, page: int = 1, per_page: int = 25) -> dict[str, Any]: + result = self._qb.paginate(page, per_page) + result["data"] = [self.model_cls.from_row(r) for r in result["data"]] + return result diff --git a/tests/test_orm.py b/tests/test_orm.py new file mode 100644 index 0000000..d122265 --- /dev/null +++ b/tests/test_orm.py @@ -0,0 +1,293 @@ +"""Cobertura del ORM (fields + Model + ModelQuery).""" +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal +from typing import Any + +import pytest + +from shiba import Model, error_codes, fields, set_default_connection +from shiba.core.query_builder import QueryBuilder +from shiba.core.table_builder import TableBuilder +from shiba.dialects.mysql import MySQLDialect +from shiba.errors import MissingDataError, ShibaError + +# --------------------------------------------------------------------------- +# Fake connection que reusa el FakeDatabase de conftest +# --------------------------------------------------------------------------- + + +class FakeShibaConnection: + def __init__(self, fake_db: Any) -> None: + self.db = fake_db + self.dialect = MySQLDialect() + + def table(self, name: str) -> QueryBuilder: + return QueryBuilder(self.db, name, dialect=self.dialect) + + def create_table(self, name: str) -> TableBuilder: + return TableBuilder(self.db, name, dialect=self.dialect) + + def raw(self, query: str, params: Any = None, *, many: bool = False): + return self.db.execute(query, params, many=many) + + +@pytest.fixture +def cx(fake_db) -> FakeShibaConnection: + conn = FakeShibaConnection(fake_db) + set_default_connection(conn) # type: ignore[arg-type] + return conn + + +# --------------------------------------------------------------------------- +# Field inference +# --------------------------------------------------------------------------- + + +def test_infer_types_from_annotations() -> None: + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + age: int | None = None + active: bool = True + score: float = 0.0 + amount: Decimal | None = None + settings: dict = fields.Json(default_factory=dict) + notes: str | None = None + created_at: datetime = fields.DateTime(default_now=True) + + f = T._fields + assert f["id"].primary_key and f["id"].auto_increment + assert f["name"].sql_type == "VARCHAR(255)" and not f["name"].nullable + assert f["age"].sql_type == "INT" and f["age"].nullable + assert f["active"].sql_type == "BOOLEAN" and f["active"].default is True + assert f["score"].sql_type == "FLOAT" + assert f["amount"].sql_type == "DECIMAL" and f["amount"].nullable + assert f["settings"].json and f["settings"].sql_type == "JSON" + assert f["notes"].nullable + assert f["created_at"].default_factory is not None + + +def test_unknown_kwarg_rejected(cx) -> None: + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + + with pytest.raises(ShibaError) as ei: + T(name="x", oops=1) + assert ei.value.code is error_codes.INVALID_DATA_FORMAT + + +def test_pk_required(cx) -> None: + class NoPk(Model): + __table__ = "nopk" + name: str + + with pytest.raises(MissingDataError) as ei: + NoPk(name="x").save() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +# --------------------------------------------------------------------------- +# Persistencia +# --------------------------------------------------------------------------- + + +def test_save_insert(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + age: int | None = None + + # Para que save() interprete LAST_INSERT_ID le damos un retorno. + seq = iter([[], [{"v": 42}]]) + + def execute(query, params=None, **kwargs): + fake_db.calls.append((query, params, kwargs.get("many", False))) + return next(seq) + + fake_db.execute = execute # type: ignore[method-assign] + + u = User(name="John", age=30) + u.save() + assert u.id == 42 + insert_sql, params, _ = fake_db.calls[0] + assert insert_sql.startswith("INSERT INTO `users` (`name`, `age`)") + assert params == ("John", 30) + + +def test_save_update_when_pk_present(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + u = User(id=5, name="John") + u.save() + sql, params, _ = fake_db.last_call + assert sql == "UPDATE `users` SET `id` = %s, `name` = %s WHERE `id` = %s" + assert params == (5, "John", 5) + + +def test_delete(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + User(id=7, name="X").delete() + sql, params, _ = fake_db.last_call + assert sql == "DELETE FROM `users` WHERE `id` = %s" + assert params == (7,) + + +def test_delete_without_pk_raises(cx) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + with pytest.raises(MissingDataError) as ei: + User(name="X").delete() + assert ei.value.code is error_codes.MISSING_REQUIRED_DATA + + +# --------------------------------------------------------------------------- +# Lectura +# --------------------------------------------------------------------------- + + +def test_find_returns_model_instance(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + fake_db.result = [{"id": 3, "name": "Alice"}] + u = User.find(3) + assert isinstance(u, User) + assert u.id == 3 + assert u.name == "Alice" + + +def test_find_returns_none_when_missing(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + fake_db.result = [] + assert User.find(99) is None + + +def test_where_returns_modelquery(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + age: int + + fake_db.result = [ + {"id": 1, "name": "A", "age": 20}, + {"id": 2, "name": "B", "age": 30}, + ] + rows = User.where("age", ">", 10).order_by("age").get() + assert all(isinstance(r, User) for r in rows) + assert [r.name for r in rows] == ["A", "B"] + sql, params, _ = fake_db.last_call + assert "WHERE `age` > %s" in sql + assert "ORDER BY `age` ASC" in sql + assert params == (10,) + + +def test_json_field_roundtrips(cx, fake_db) -> None: + class Doc(Model): + __table__ = "docs" + id: int = fields.PrimaryKey() + payload: dict = fields.Json(default_factory=dict) + + fake_db.result = [{"id": 1, "payload": '{"a": 1}'}] + d = Doc.find(1) + assert d is not None + assert d.payload == {"a": 1} + + # Y al guardar, payload se serializa. + fake_db.result = [] + Doc(payload={"x": "y"}).save() + insert_calls = [c for c in fake_db.calls if c[0].startswith("INSERT INTO `docs`")] + assert insert_calls, "esperaba al menos un INSERT" + sql, params, _ = insert_calls[-1] + assert params == ('{"x": "y"}',) + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + + +def test_create_table_emits_ddl(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str = fields.String(max_length=50) + email: str = fields.String(unique=True) + age: int | None = None + settings: dict = fields.Json(default_factory=dict) + + User.create_table() + sql, _, _ = fake_db.last_call + assert "CREATE TABLE IF NOT EXISTS `users`" in sql + assert "`id` INT AUTO_INCREMENT PRIMARY KEY" in sql + assert "`name` VARCHAR(50) NOT NULL" in sql + assert "`email` VARCHAR(255) UNIQUE NOT NULL" in sql + assert "`age` INT NULL" in sql + assert "`settings` JSON NOT NULL" in sql + + +def test_truncate_table(cx, fake_db) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + User.truncate_table() + sql, _, _ = fake_db.last_call + assert sql == "TRUNCATE TABLE `users`" + + +def test_default_table_name_from_class(cx) -> None: + class Customer(Model): + id: int = fields.PrimaryKey() + name: str + + assert Customer._table == "customer" + + +def test_to_dict_and_from_row_roundtrip(cx) -> None: + class User(Model): + __table__ = "users" + id: int = fields.PrimaryKey() + name: str + + u = User.from_row({"id": 1, "name": "X"}) + assert u.to_dict() == {"id": 1, "name": "X"} + + +def test_no_connection_raises(monkeypatch) -> None: + from shiba.orm import model as model_mod + + monkeypatch.setattr(model_mod, "_default_connection", None) + + class T(Model): + __table__ = "t" + id: int = fields.PrimaryKey() + name: str + + T.__db__ = None + with pytest.raises(ShibaError) as ei: + T.find(1) + assert ei.value.code is error_codes.CONNECTION_NOT_OPEN