From 72afcb737d92c298aeeb4d80740097d053d05f56 Mon Sep 17 00:00:00 2001 From: Ilia Ablamonov Date: Wed, 11 Mar 2026 18:15:23 +0100 Subject: [PATCH 1/2] Add pool_options argument --- example/config.py | 3 + example/db/mydb.py | 4 +- example/generate.py | 1 + src/iron_sql/__init__.py | 2 + src/iron_sql/codegen/generator.py | 99 +++++++++++++++---- src/iron_sql/runtime.py | 38 +++++++- tests/test_code_generation.py | 156 ++++++++++++++++++++++++++++++ tests/test_runtime.py | 17 ++++ 8 files changed, 296 insertions(+), 24 deletions(-) diff --git a/example/config.py b/example/config.py index 0d119b6..0ecb315 100644 --- a/example/config.py +++ b/example/config.py @@ -1,3 +1,6 @@ import os +from iron_sql import PoolOptions + DSN = os.environ.get("DATABASE_URL", "") +POOL_OPTIONS: PoolOptions = {"min_size": 1, "max_size": 10, "timeout": 15.0} diff --git a/example/db/mydb.py b/example/db/mydb.py index 56974cd..88a99d8 100644 --- a/example/db/mydb.py +++ b/example/db/mydb.py @@ -27,13 +27,15 @@ from iron_sql import runtime from example.config import DSN - +from example.config import POOL_OPTIONS import example.models + MYDB_POOL = runtime.ConnectionPool( DSN, name="mydb", application_name=None, + pool_options=POOL_OPTIONS, ) _mydb_connection = ContextVar[psycopg.AsyncConnection | None]( diff --git a/example/generate.py b/example/generate.py index 93c8e78..cee54c3 100644 --- a/example/generate.py +++ b/example/generate.py @@ -20,6 +20,7 @@ def generate_db_package(dsn: str, schema_path: Path, src_path: Path) -> bool: schema_path=schema_path, package_full_name="example.db.mydb", dsn_import="example.config:DSN", + pool_options_import="example.config:POOL_OPTIONS", src_path=src_path, json_model_overrides={ "projects.settings": "example.models:ProjectSettings", diff --git a/src/iron_sql/__init__.py b/src/iron_sql/__init__.py index 7eac207..0932400 100644 --- a/src/iron_sql/__init__.py +++ b/src/iron_sql/__init__.py @@ -1,9 +1,11 @@ """iron_sql: Typed SQL client generator for Python.""" from iron_sql.runtime import NoRowsError +from iron_sql.runtime import PoolOptions from iron_sql.runtime import TooManyRowsError __all__ = [ "NoRowsError", + "PoolOptions", "TooManyRowsError", ] diff --git a/src/iron_sql/codegen/generator.py b/src/iron_sql/codegen/generator.py index 750235b..b0049c3 100644 --- a/src/iron_sql/codegen/generator.py +++ b/src/iron_sql/codegen/generator.py @@ -71,6 +71,50 @@ class UnknownSQLTypeWarning(UserWarning): pass +@dataclass(kw_only=True, frozen=True) +class ModuleExprRef: + module_name: str + module_expr: str + + @classmethod + def parse(cls, value: str) -> "ModuleExprRef": + module_name, sep, module_expr = value.partition(":") + if not sep: + msg = f"module expression must be 'module:expr', got: {value!r}" + raise ValueError(msg) + match = re.match(r"[A-Za-z_][A-Za-z0-9_]*", module_expr) + if match is None: + msg = ( + "module expression must start with identifier, " + f"got: {module_expr!r}" + ) + raise ValueError(msg) + return cls(module_name=module_name, module_expr=module_expr) + + @property + def import_name(self) -> str: + match = re.match(r"[A-Za-z_][A-Za-z0-9_]*", self.module_expr) + if match is None: + msg = ( + "module expression must start with identifier, " + f"got: {self.module_expr!r}" + ) + raise ValueError(msg) + return match.group() + + def evaluate[T](self, *, expected_type: type[T]) -> T: + mod = importlib.import_module(self.module_name) + value = eval(self.module_expr, vars(mod)) # noqa: S307 + if not isinstance(value, expected_type): + msg = ( + f"module expression {self.module_name}:{self.module_expr} " + f"must evaluate to " + f"{expected_type.__name__}, got: {type(value).__name__}" + ) + raise TypeError(msg) + return value + + _SQL_TYPE_MAP: dict[str, str] = { "bool": "bool", "boolean": "bool", @@ -210,6 +254,7 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 schema_path: Path, package_full_name: str, dsn_import: str, + pool_options_import: str | None = None, application_name: str | None = None, type_overrides: dict[str, str] | None = None, json_model_overrides: dict[str, str] | None = None, @@ -224,7 +269,15 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 queries, all_locations = collect_queries(src_path, sql_fn_name) - dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import) + dsn_ref = ModuleExprRef.parse(dsn_import) + dsn = dsn_ref.evaluate(expected_type=str) + pool_options_ref = ( + ModuleExprRef.parse(pool_options_import) + if pool_options_import is not None + else None + ) + if pool_options_ref is not None: + pool_options_ref.evaluate(expected_type=dict) sqlc_res, block_starts = run_sqlc( src_path / schema_path, @@ -283,8 +336,7 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py" new_content = render_package( - dsn_import_package, - dsn_import_path, + dsn_ref, package_name, sql_fn_name, entities, @@ -294,6 +346,7 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 query_dict_entries, application_name, json_import_block, + pool_options_ref, ) changed = write_if_changed(target_package_path, new_content + "\n") if changed: @@ -316,13 +369,6 @@ def collect_queries( return queries, all_locations -def resolve_dsn(dsn_import: str) -> tuple[str, str, str]: - package_name, attr_path = dsn_import.split(":") - mod = importlib.import_module(package_name) - dsn: str = eval(attr_path, vars(mod)) # noqa: S307 - return dsn, package_name, attr_path - - def render_query_classes( sqlc_queries: tuple[Query, ...], queries: list["CodeQuery"], @@ -414,8 +460,7 @@ def resolve_json_model_overrides( def render_package( # noqa: PLR0913, PLR0917 - dsn_import_package: str, - dsn_import_path: str, + dsn_ref: ModuleExprRef, package_name: str, sql_fn_name: str, entities: list[str], @@ -425,7 +470,27 @@ def render_package( # noqa: PLR0913, PLR0917 query_dict_entries: list[str], application_name: str | None = None, json_import_block: str = "", -): + pool_options_ref: ModuleExprRef | None = None, +) -> str: + imports = [f"from {dsn_ref.module_name} import {dsn_ref.import_name}"] + pool_args = [ + dsn_ref.module_expr, + f'name="{package_name}"', + f"application_name={application_name!r}", + ] + if pool_options_ref is not None: + imports.append( + f"from {pool_options_ref.module_name} import {pool_options_ref.import_name}" + ) + pool_args.append(f"pool_options={pool_options_ref.module_expr}") + + if json_import_block: + imports.extend(json_import_block.strip().splitlines()) + + imports_block = "\n".join(imports) + + pool_args_str = ",\n ".join(pool_args) + return f""" # Code generated by iron_sql, DO NOT EDIT. @@ -456,13 +521,11 @@ def render_package( # noqa: PLR0913, PLR0917 from iron_sql import runtime -from {dsn_import_package} import {dsn_import_path.split(".", maxsplit=1)[0]} -{json_import_block} +{imports_block} + {package_name.upper()}_POOL = runtime.ConnectionPool( - {dsn_import_path}, - name="{package_name}", - application_name={application_name!r}, + {pool_args_str}, ) _{package_name}_connection = ContextVar[psycopg.AsyncConnection | None]( diff --git a/src/iron_sql/runtime.py b/src/iron_sql/runtime.py index 555251c..ce1dc3a 100644 --- a/src/iron_sql/runtime.py +++ b/src/iron_sql/runtime.py @@ -4,6 +4,7 @@ import types from collections.abc import AsyncGenerator from collections.abc import AsyncIterator +from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Sequence from contextlib import asynccontextmanager @@ -13,6 +14,7 @@ from typing import ClassVar from typing import Literal from typing import Self +from typing import TypedDict from typing import overload import psycopg @@ -162,6 +164,22 @@ async def _server_cursor( yield cur +class PoolOptions(TypedDict, total=False): + min_size: int + max_size: int | None + timeout: float + max_waiting: int + max_lifetime: float + max_idle: float + reconnect_timeout: float + num_workers: int + kwargs: dict[str, Any] + configure: Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]] + check: Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]] + reset: Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]] + reconnect_failed: Callable[[psycopg_pool.AsyncConnectionPool[Any]], Awaitable[None]] + + class ConnectionPool: def __init__( self, @@ -169,10 +187,12 @@ def __init__( *, name: str | None = None, application_name: str | None = None, + pool_options: PoolOptions | None = None, ) -> None: self.conninfo = conninfo self.name = name self.application_name = application_name + self.pool_options = pool_options or {} self._init_psycopg_pool() async def close(self) -> None: @@ -209,15 +229,23 @@ async def connection(self) -> AsyncIterator[psycopg.AsyncConnection]: yield conn def _init_psycopg_pool(self) -> None: + user_kwargs: dict[str, Any] = self.pool_options.get("kwargs", {}) + forwarded: dict[str, Any] = { + k: v for k, v in self.pool_options.items() if k != "kwargs" + } + conn_kwargs = { + **user_kwargs, + # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions + "autocommit": True, + } + if self.application_name is not None: + conn_kwargs["application_name"] = self.application_name self.psycopg_pool = psycopg_pool.AsyncConnectionPool( self.conninfo, + **forwarded, open=False, name=self.name, - kwargs={ - "application_name": self.application_name, - # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions - "autocommit": True, - }, + kwargs=conn_kwargs, ) @asynccontextmanager diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 9846c5e..71aeef6 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -2,11 +2,13 @@ import logging import re import sys +from ast import parse from pathlib import Path import pytest from iron_sql.codegen import generate_sql_package +from iron_sql.codegen.generator import ModuleExprRef from iron_sql.codegen.sqlc import run_sqlc from tests.conftest import ProjectBuilder @@ -215,6 +217,27 @@ async def test_special_types_params(test_project: ProjectBuilder) -> None: assert test_project.generate_no_import() is True +def test_module_expr_ref_parse_and_evaluate(test_project: ProjectBuilder) -> None: + (test_project.app_dir / "config.py").write_text( + f"""DSN = "{test_project.dsn}" + +def get_dsn() -> str: + return DSN +""", + encoding="utf-8", + ) + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + expr_ref = ModuleExprRef.parse(f"{test_project.app_pkg}.config:get_dsn()") + + assert expr_ref.module_name == f"{test_project.app_pkg}.config" + assert expr_ref.module_expr == "get_dsn()" + assert expr_ref.import_name == "get_dsn" + assert expr_ref.evaluate(expected_type=str) == test_project.dsn + + def test_dsn_import_with_function_call(test_project: ProjectBuilder) -> None: (test_project.app_dir / "config.py").write_text( f""" @@ -249,6 +272,139 @@ def get_dsn(self) -> str: assert "CONFIG.get_dsn()" in generated +def test_dsn_import_with_factory_call_generates_valid_python( + test_project: ProjectBuilder, +) -> None: + (test_project.app_dir / "config.py").write_text( + f"""DSN = "{test_project.dsn}" + +def get_dsn() -> str: + return DSN +""", + encoding="utf-8", + ) + + test_project.add_query("q", "SELECT 1 as value") + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + generate_sql_package( + schema_path=Path("schema.sql"), + package_full_name=test_project.pkg_name, + dsn_import=f"{test_project.app_pkg}.config:get_dsn()", + src_path=test_project.src_path, + tempdir_path=test_project.src_path, + ) + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + generated = generated_path.read_text() + assert f"from {test_project.app_pkg}.config import get_dsn" in generated + assert "get_dsn()" in generated + parse(generated) + + +def test_pool_options_import(test_project: ProjectBuilder) -> None: + config = ( + f'DSN = "{test_project.dsn}"\nPOOL_OPTIONS = {{"min_size": 1, "max_size": 5}}\n' + ) + (test_project.app_dir / "config.py").write_text(config, encoding="utf-8") + + test_project.add_query("q", "SELECT 1 as value") + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + generate_sql_package( + schema_path=Path("schema.sql"), + package_full_name=test_project.pkg_name, + dsn_import=f"{test_project.app_pkg}.config:DSN", + pool_options_import=f"{test_project.app_pkg}.config:POOL_OPTIONS", + src_path=test_project.src_path, + tempdir_path=test_project.src_path, + ) + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + generated = generated_path.read_text() + assert f"from {test_project.app_pkg}.config import POOL_OPTIONS" in generated + assert "pool_options=POOL_OPTIONS" in generated + + +def test_pool_options_import_factory_generates_valid_python( + test_project: ProjectBuilder, +) -> None: + config = f"""DSN = "{test_project.dsn}" + +def get_pool_options() -> dict[str, object]: + return {{"min_size": 1, "max_size": 5}} +""" + (test_project.app_dir / "config.py").write_text(config, encoding="utf-8") + + test_project.add_query("q", "SELECT 1 as value") + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + generate_sql_package( + schema_path=Path("schema.sql"), + package_full_name=test_project.pkg_name, + dsn_import=f"{test_project.app_pkg}.config:DSN", + pool_options_import=f"{test_project.app_pkg}.config:get_pool_options()", + src_path=test_project.src_path, + tempdir_path=test_project.src_path, + ) + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + generated = generated_path.read_text() + assert f"from {test_project.app_pkg}.config import get_pool_options" in generated + assert "pool_options=get_pool_options()" in generated + parse(generated) + + +def test_pool_options_import_invalid_fails_during_generation( + test_project: ProjectBuilder, +) -> None: + config = f'DSN = "{test_project.dsn}"\n' + (test_project.app_dir / "config.py").write_text(config, encoding="utf-8") + + test_project.add_query("q", "SELECT 1 as value") + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + with pytest.raises(NameError, match="MISSING_POOL_OPTIONS"): + generate_sql_package( + schema_path=Path("schema.sql"), + package_full_name=test_project.pkg_name, + dsn_import=f"{test_project.app_pkg}.config:DSN", + pool_options_import=f"{test_project.app_pkg}.config:MISSING_POOL_OPTIONS", + src_path=test_project.src_path, + tempdir_path=test_project.src_path, + ) + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + assert not generated_path.exists() + + +def test_pool_options_import_not_set(test_project: ProjectBuilder) -> None: + test_project.add_query("q", "SELECT 1 as value") + test_project.generate_no_import() + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + generated = generated_path.read_text() + assert "**" not in generated + + def test_package_is_marked_as_typed() -> None: assert importlib.resources.files("iron_sql").joinpath("py.typed").is_file() diff --git a/tests/test_runtime.py b/tests/test_runtime.py index db4ce83..d5f9817 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -121,6 +121,23 @@ async def test_typed_scalar_row_not_null_raises_on_none( await cur.fetchone() +async def test_pool_forwards_pool_options(pg_dsn: str) -> None: + async with ConnectionPool(pg_dsn, pool_options={"min_size": 1, "max_size": 2}) as p: + assert p.psycopg_pool.min_size == 1 + assert p.psycopg_pool.max_size == 2 + await p.check() + + +async def test_pool_preserves_application_name_from_pool_options_kwargs() -> None: + pool = ConnectionPool( + "postgresql://example.invalid/db", + pool_options={"kwargs": {"application_name": "from-pool-options"}}, + ) + assert isinstance(pool.psycopg_pool.kwargs, dict) + assert pool.psycopg_pool.kwargs["application_name"] == "from-pool-options" + await pool.psycopg_pool.close() + + async def test_typed_scalar_row_type_mismatch(pool: ConnectionPool) -> None: async with ( pool.connection() as conn, From 1b79779776d876b722a6926540fc329b2d6fca58 Mon Sep 17 00:00:00 2001 From: Ilia Ablamonov Date: Wed, 11 Mar 2026 20:08:52 +0100 Subject: [PATCH 2/2] Introduce better naming --- README.md | 14 +-- example/db/mydb.py | 30 ++--- example/generate.py | 16 +-- src/iron_sql/codegen/__init__.py | 4 +- src/iron_sql/codegen/generator.py | 192 ++++++++++++++--------------- src/iron_sql/codegen/sqlc.py | 10 +- tests/conftest.py | 14 +-- tests/test_code_generation.py | 64 +++++----- tests/test_json_model_overrides.py | 2 +- 9 files changed, 173 insertions(+), 173 deletions(-) diff --git a/README.md b/README.md index 30c2004..5703945 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![PyPI - Version](https://img.shields.io/pypi/v/iron-sql)](https://pypi.org/project/iron-sql/) -`iron_sql` is a typed SQL code generator and async runtime for PostgreSQL. Write SQL where you use it, run `generate_sql_package`, and get a module with typed dataclasses, query helpers, and pooled connections without hand-written boilerplate. +`iron_sql` is a typed SQL code generator and async runtime for PostgreSQL. Write SQL where you use it, run `generate_sql_module`, and get a module with typed dataclasses, query helpers, and pooled connections without hand-written boilerplate. ## Installation @@ -17,7 +17,7 @@ pip install iron-sql[codegen] # + inflection for code generation The `sqlc` binary is bundled automatically via the `sqlc` Python package. ## Key Features -- **Query discovery.** `generate_sql_package` scans your codebase for calls like `_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module. +- **Query discovery.** `generate_sql_module` scans your codebase for calls like `_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module. - **Strong typing.** Generated dataclasses and method signatures flow through your IDE and type checker. - **Async runtime.** Built on `psycopg` v3 with pooled connections, context-based connection reuse, and transaction helpers. - **Streaming.** `query_stream()` uses server-side cursors for memory-efficient iteration over large result sets. @@ -44,12 +44,12 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package. ```python from pathlib import Path - from iron_sql.codegen import generate_sql_package + from iron_sql.codegen import generate_sql_module - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name="myapp.db.mydb", - dsn_import="myapp.config:DSN", + module_full_name="myapp.db.mydb", + dsn_expr="myapp.config:DSN", src_path=Path("."), ) ``` @@ -66,7 +66,7 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package. - **Type overrides.** `type_overrides={"custom_type": "int"}` maps database type names to Python type strings. - **JSON model overrides.** `json_model_overrides={"users.metadata": "myapp.models:UserMeta"}` adds Pydantic validation for JSON/JSONB columns. - **Naming conventions.** Supply `to_pascal_fn` and `to_snake_fn` callables to control generated names. -- **DSN configuration.** `dsn_import` is written verbatim into the generated module; point it at a config variable, env var lookup, or function call. +- **Connection settings.** `dsn_expr` and `pool_options_expr` are written verbatim into the generated module; point them at config variables, env var lookups, or function calls. - **Debug artifacts.** Pass `debug_path` to save sqlc inputs and outputs for inspection. ## Runtime Highlights diff --git a/example/db/mydb.py b/example/db/mydb.py index 88a99d8..acc5adb 100644 --- a/example/db/mydb.py +++ b/example/db/mydb.py @@ -297,31 +297,31 @@ def query_stream(self, *, status: MydbTaskStatus) -> AbstractAsyncContextManager @overload -def mydb_sql(stmt: Literal['\n INSERT INTO users (id, username, email)\n VALUES (@id, @username, @email)\n ']) -> Query_3ee53b6909da8b4496346dda36c9f442: ... +def mydb_sql(sql: Literal['\n INSERT INTO users (id, username, email)\n VALUES (@id, @username, @email)\n ']) -> Query_3ee53b6909da8b4496346dda36c9f442: ... @overload -def mydb_sql(stmt: Literal['\n INSERT INTO projects (id, name, owner_id, settings)\n VALUES (@id, @name, @owner_id, @settings)\n ']) -> Query_67ac0768d48a654b1a305124c92372e8: ... +def mydb_sql(sql: Literal['\n INSERT INTO projects (id, name, owner_id, settings)\n VALUES (@id, @name, @owner_id, @settings)\n ']) -> Query_67ac0768d48a654b1a305124c92372e8: ... @overload -def mydb_sql(stmt: Literal['\n INSERT INTO tasks (id, project_id, title, priority, assignee_id, metadata, due_date)\n VALUES (@id, @project_id, @title, @priority, @assignee_id?, @metadata?, @due_date?)\n ']) -> Query_bd4c62c78a942bfd1f087f87a19f2743: ... +def mydb_sql(sql: Literal['\n INSERT INTO tasks (id, project_id, title, priority, assignee_id, metadata, due_date)\n VALUES (@id, @project_id, @title, @priority, @assignee_id?, @metadata?, @due_date?)\n ']) -> Query_bd4c62c78a942bfd1f087f87a19f2743: ... @overload -def mydb_sql(stmt: Literal['UPDATE tasks SET status = @status WHERE id = @task_id']) -> Query_12e061f7aa94bf484295ab0018520059: ... +def mydb_sql(sql: Literal['UPDATE tasks SET status = @status WHERE id = @task_id']) -> Query_12e061f7aa94bf484295ab0018520059: ... @overload -def mydb_sql(stmt: Literal['SELECT id, username, email, created_at FROM users ORDER BY created_at']) -> Query_46242a02ffe365dc17851a034fdc1d30: ... +def mydb_sql(sql: Literal['SELECT id, username, email, created_at FROM users ORDER BY created_at']) -> Query_46242a02ffe365dc17851a034fdc1d30: ... @overload -def mydb_sql(stmt: Literal['SELECT id, username, email, created_at FROM users WHERE id = @user_id']) -> Query_41cb2f3cea216a76ba87b6ddb70e6be5: ... +def mydb_sql(sql: Literal['SELECT id, username, email, created_at FROM users WHERE id = @user_id']) -> Query_41cb2f3cea216a76ba87b6ddb70e6be5: ... @overload -def mydb_sql(stmt: Literal["\n SELECT id, project_id, assignee_id, title, status, priority, metadata, due_date, created_at\n FROM tasks\n WHERE project_id = @project_id AND (sqlc.narg('status')::task_status IS NULL OR status = @status?)\n "]) -> Query_ce9822661c2a7e0e716755087929ebd9: ... +def mydb_sql(sql: Literal["\n SELECT id, project_id, assignee_id, title, status, priority, metadata, due_date, created_at\n FROM tasks\n WHERE project_id = @project_id AND (sqlc.narg('status')::task_status IS NULL OR status = @status?)\n "]) -> Query_ce9822661c2a7e0e716755087929ebd9: ... @overload -def mydb_sql(stmt: Literal['\n SELECT status, count(*) AS task_count\n FROM tasks WHERE project_id = @project_id\n GROUP BY status ORDER BY status\n '], row_type: Literal['TaskStatusCount']) -> Query_cabe6d4d91163f6aadc739bf765777db_TaskStatusCount: ... +def mydb_sql(sql: Literal['\n SELECT status, count(*) AS task_count\n FROM tasks WHERE project_id = @project_id\n GROUP BY status ORDER BY status\n '], row_type: Literal['TaskStatusCount']) -> Query_cabe6d4d91163f6aadc739bf765777db_TaskStatusCount: ... @overload -def mydb_sql(stmt: Literal['SELECT id FROM tasks WHERE project_id = @project_id AND title = @title']) -> Query_07cbb3e5226e35adbd17171f38ab7216: ... +def mydb_sql(sql: Literal['SELECT id FROM tasks WHERE project_id = @project_id AND title = @title']) -> Query_07cbb3e5226e35adbd17171f38ab7216: ... @overload -def mydb_sql(stmt: Literal['SELECT count(*) FROM tasks WHERE status = @status']) -> Query_29c838280e39383dd6b0760431eb3e60: ... +def mydb_sql(sql: Literal['SELECT count(*) FROM tasks WHERE status = @status']) -> Query_29c838280e39383dd6b0760431eb3e60: ... @overload -def mydb_sql(stmt: str) -> Query: ... +def mydb_sql(sql: str) -> Query: ... -def mydb_sql(stmt: str, row_type: str | None = None) -> Query: - if stmt in _QUERIES: - return _QUERIES[stmt]() - msg = f"Unknown statement: {stmt!r}" +def mydb_sql(sql: str, row_type: str | None = None) -> Query: + if sql in _QUERIES: + return _QUERIES[sql]() + msg = f"Unknown statement: {sql!r}" raise KeyError(msg) diff --git a/example/generate.py b/example/generate.py index cee54c3..29fe2b7 100644 --- a/example/generate.py +++ b/example/generate.py @@ -4,7 +4,7 @@ import psycopg from testcontainers.postgres import PostgresContainer -from iron_sql.codegen import generate_sql_package +from iron_sql.codegen import generate_sql_module def init_db(dsn: str, schema_path: Path): @@ -12,15 +12,15 @@ def init_db(dsn: str, schema_path: Path): conn.execute(schema_path.read_text(encoding="utf-8")) # pyright: ignore[reportCallIssue, reportArgumentType] -def generate_db_package(dsn: str, schema_path: Path, src_path: Path) -> bool: +def generate_db_module(dsn: str, schema_path: Path, src_path: Path) -> bool: # For example.config:DSN os.environ["DATABASE_URL"] = dsn - return generate_sql_package( + return generate_sql_module( schema_path=schema_path, - package_full_name="example.db.mydb", - dsn_import="example.config:DSN", - pool_options_import="example.config:POOL_OPTIONS", + module_full_name="example.db.mydb", + dsn_expr="example.config:DSN", + pool_options_expr="example.config:POOL_OPTIONS", src_path=src_path, json_model_overrides={ "projects.settings": "example.models:ProjectSettings", @@ -37,5 +37,5 @@ def generate_db_package(dsn: str, schema_path: Path, src_path: Path) -> bool: with PostgresContainer("postgres:17-alpine") as postgres: dsn = postgres.get_connection_url(driver=None) init_db(dsn, schema_path) - changed = generate_db_package(dsn, schema_path, src_path) - print("Updated SQL package:", changed) + changed = generate_db_module(dsn, schema_path, src_path) + print("Updated SQL module:", changed) diff --git a/src/iron_sql/codegen/__init__.py b/src/iron_sql/codegen/__init__.py index fed3b7a..b38979d 100644 --- a/src/iron_sql/codegen/__init__.py +++ b/src/iron_sql/codegen/__init__.py @@ -1,7 +1,7 @@ from iron_sql.codegen.generator import UnknownSQLTypeWarning -from iron_sql.codegen.generator import generate_sql_package +from iron_sql.codegen.generator import generate_sql_module __all__ = [ "UnknownSQLTypeWarning", - "generate_sql_package", + "generate_sql_module", ] diff --git a/src/iron_sql/codegen/generator.py b/src/iron_sql/codegen/generator.py index b0049c3..0532686 100644 --- a/src/iron_sql/codegen/generator.py +++ b/src/iron_sql/codegen/generator.py @@ -84,10 +84,7 @@ def parse(cls, value: str) -> "ModuleExprRef": raise ValueError(msg) match = re.match(r"[A-Za-z_][A-Za-z0-9_]*", module_expr) if match is None: - msg = ( - "module expression must start with identifier, " - f"got: {module_expr!r}" - ) + msg = f"module expression must start with identifier, got: {module_expr!r}" raise ValueError(msg) return cls(module_name=module_name, module_expr=module_expr) @@ -155,11 +152,11 @@ def evaluate[T](self, *, expected_type: type[T]) -> T: @dataclass(kw_only=True, frozen=True) class TypeResolver: catalog: Catalog - package_name: str + module_name: str to_pascal_fn: Callable[[str], str] to_snake_fn: Callable[[str], str] type_overrides: dict[str, str] - json_col_overrides: dict[tuple[str, str], str] + json_column_type_overrides: dict[tuple[str, str], str] def column_spec(self, column: Column) -> ColumnSpec: _, py_type, json_type = self._resolve(column) @@ -188,7 +185,10 @@ def _resolve(self, column: Column) -> tuple[str, str, str | None]: json_type = None if column.table is not None: col_name = column.original_name or column.name - json_type = self.json_col_overrides.get((column.table.name, col_name)) + json_type = self.json_column_type_overrides.get(( + column.table.name, + col_name, + )) if json_type: py_type = json_type @@ -198,8 +198,8 @@ def _resolve(self, column: Column) -> tuple[str, str, str | None]: py_type = _SQL_TYPE_MAP[db_type] elif self.catalog.schema_by_ref(column.type).has_enum(db_type): py_type = ( - self.to_pascal_fn(f"{self.package_name}_{self.to_snake_fn(db_type)}") - if self.package_name + self.to_pascal_fn(f"{self.module_name}_{self.to_snake_fn(db_type)}") + if self.module_name else "str" ) else: @@ -234,14 +234,14 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]: def map_sqlc_error( error: str, block_starts: list[tuple[int, str]], - all_locations: dict[str, list[str]], + query_locations_by_name: dict[str, list[str]], ) -> str: def replace(m: re.Match[str]) -> str: line = int(m.group(1)) name = next((n for start, n in reversed(block_starts) if start <= line), None) if name is None: return m.group(0) - locations = all_locations.get(name) + locations = query_locations_by_name.get(name) if not locations: return m.group(0) return f"{', '.join(locations)}:" @@ -249,12 +249,12 @@ def replace(m: re.Match[str]) -> str: return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error) -def generate_sql_package( # noqa: PLR0913, PLR0914 +def generate_sql_module( # noqa: PLR0913, PLR0914 *, schema_path: Path, - package_full_name: str, - dsn_import: str, - pool_options_import: str | None = None, + module_full_name: str, + dsn_expr: str, + pool_options_expr: str | None = None, application_name: str | None = None, type_overrides: dict[str, str] | None = None, json_model_overrides: dict[str, str] | None = None, @@ -264,16 +264,16 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 src_path: Path = Path(), tempdir_path: Path | None = None, ) -> bool: - package_name = package_full_name.rsplit(".", maxsplit=1)[-1] - sql_fn_name = f"{package_name}_sql" + module_name = module_full_name.rsplit(".", maxsplit=1)[-1] + sql_fn_name = f"{module_name}_sql" - queries, all_locations = collect_queries(src_path, sql_fn_name) + queries, query_locations_by_name = collect_queries(src_path, sql_fn_name) - dsn_ref = ModuleExprRef.parse(dsn_import) + dsn_ref = ModuleExprRef.parse(dsn_expr) dsn = dsn_ref.evaluate(expected_type=str) pool_options_ref = ( - ModuleExprRef.parse(pool_options_import) - if pool_options_import is not None + ModuleExprRef.parse(pool_options_expr) + if pool_options_expr is not None else None ) if pool_options_ref is not None: @@ -281,31 +281,31 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 sqlc_res, block_starts = run_sqlc( src_path / schema_path, - [(q.name, q.stmt) for q in queries], + [(q.name, q.sql) for q in queries], dsn=dsn, debug_path=debug_path, tempdir_path=tempdir_path, ) if sqlc_res.error: - mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations) + mapped = map_sqlc_error(sqlc_res.error, block_starts, query_locations_by_name) logger.error(f"Error running SQLC:\n{mapped}") return False - json_import_block, json_col_overrides = resolve_json_model_overrides( + json_import_block, json_column_type_overrides = resolve_json_model_overrides( json_model_overrides or {}, sqlc_res.catalog ) resolver = TypeResolver( catalog=sqlc_res.catalog, - package_name=package_name, + module_name=module_name, to_pascal_fn=to_pascal_fn, to_snake_fn=to_snake_fn, type_overrides=type_overrides or {}, - json_col_overrides=json_col_overrides, + json_column_type_overrides=json_column_type_overrides, ) - ordered_entities, result_types = map_entities( + ordered_entities, query_result_types = build_entities( sqlc_res.queries, sqlc_res.used_schemas(), queries, @@ -317,27 +317,27 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 used_enums = collect_used_enums(sqlc_res) enums = sorted( - render_enum_class(e, package_name, to_pascal_fn, to_snake_fn) + render_enum_class(e, module_name, to_pascal_fn, to_snake_fn) for schema in sqlc_res.catalog.schemas for e in schema.enums if (schema.name, e.name) in used_enums ) query_classes = render_query_classes( - sqlc_res.queries, queries, resolver, result_types, all_locations + sqlc_res.queries, queries, resolver, query_result_types, query_locations_by_name ) query_overloads = [ - render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries + render_query_overload(sql_fn_name, q.name, q.sql, q.row_type) for q in queries ] - query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries] + query_dict_entries = [render_query_dict_entry(q.name, q.sql) for q in queries] - target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py" + target_module_path = src_path / f"{module_full_name.replace('.', '/')}.py" - new_content = render_package( + new_content = render_module( dsn_ref, - package_name, + module_name, sql_fn_name, entities, enums, @@ -348,9 +348,9 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 json_import_block, pool_options_ref, ) - changed = write_if_changed(target_package_path, new_content + "\n") + changed = write_if_changed(target_module_path, new_content + "\n") if changed: - logger.info(f"Generated SQL package {package_full_name}") + logger.info(f"Generated SQL module {module_full_name}") return changed @@ -358,23 +358,23 @@ def collect_queries( src_path: Path, sql_fn_name: str ) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]: raw = list(find_all_queries(src_path, sql_fn_name)) - validate_stmt_has_single_row_type(raw) - all_locations: defaultdict[str, list[str]] = defaultdict(list) + validate_sql_has_single_row_type(raw) + query_locations_by_name: defaultdict[str, list[str]] = defaultdict(list) first_occurrence: dict[str, CodeQuery] = {} for q in raw: - all_locations[q.name].append(q.location) + query_locations_by_name[q.name].append(q.location) if q.name not in first_occurrence: first_occurrence[q.name] = q queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno)) - return queries, all_locations + return queries, query_locations_by_name def render_query_classes( sqlc_queries: tuple[Query, ...], queries: list["CodeQuery"], resolver: TypeResolver, - result_types: dict[str, str], - all_locations: defaultdict[str, list[str]], + query_result_types: dict[str, str], + query_locations_by_name: defaultdict[str, list[str]], ) -> list[str]: query_order = {q.name: i for i, q in enumerate(queries)} return [ @@ -389,14 +389,14 @@ def render_query_classes( ) for p in q.params ], - result_types[q.name], + query_result_types[q.name], len(q.columns), ( resolver.column_spec(q.columns[0]).json_type if len(q.columns) == 1 else None ), - all_locations[q.name], + query_locations_by_name[q.name], ) for q in sorted(sqlc_queries, key=lambda q: query_order[q.name]) ] @@ -459,9 +459,9 @@ def resolve_json_model_overrides( return import_block, col_overrides -def render_package( # noqa: PLR0913, PLR0917 +def render_module( # noqa: PLR0913, PLR0917 dsn_ref: ModuleExprRef, - package_name: str, + module_name: str, sql_fn_name: str, entities: list[str], enums: list[str], @@ -475,7 +475,7 @@ def render_package( # noqa: PLR0913, PLR0917 imports = [f"from {dsn_ref.module_name} import {dsn_ref.import_name}"] pool_args = [ dsn_ref.module_expr, - f'name="{package_name}"', + f'name="{module_name}"', f"application_name={application_name!r}", ] if pool_options_ref is not None: @@ -524,39 +524,39 @@ def render_package( # noqa: PLR0913, PLR0917 {imports_block} -{package_name.upper()}_POOL = runtime.ConnectionPool( +{module_name.upper()}_POOL = runtime.ConnectionPool( {pool_args_str}, ) -_{package_name}_connection = ContextVar[psycopg.AsyncConnection | None]( - "_{package_name}_connection", +_{module_name}_connection = ContextVar[psycopg.AsyncConnection | None]( + "_{module_name}_connection", default=None, ) @asynccontextmanager -async def {package_name}_connection() -> AsyncIterator[psycopg.AsyncConnection]: - async with {package_name.upper()}_POOL.connection_in_context(_{package_name}_connection) as conn: +async def {module_name}_connection() -> AsyncIterator[psycopg.AsyncConnection]: + async with {module_name.upper()}_POOL.connection_in_context(_{module_name}_connection) as conn: yield conn @asynccontextmanager -async def {package_name}_transaction() -> AsyncIterator[None]: - async with {package_name}_connection() as conn, conn.transaction(): +async def {module_name}_transaction() -> AsyncIterator[None]: + async with {module_name}_connection() as conn, conn.transaction(): yield @asynccontextmanager -async def {package_name}_listen_session( +async def {module_name}_listen_session( channel: str, ) -> AsyncIterator[AsyncGenerator[str]]: - async with {package_name.upper()}_POOL.connection() as conn: + async with {module_name.upper()}_POOL.connection() as conn: async with runtime.listen(conn, channel) as payloads: yield payloads -async def {package_name}_notify(channel: str, payload: str = "") -> None: - async with {package_name}_connection() as conn: +async def {module_name}_notify(channel: str, payload: str = "") -> None: + async with {module_name}_connection() as conn: await runtime.notify(conn, channel, payload) @@ -567,7 +567,7 @@ async def {package_name}_notify(channel: str, payload: str = "") -> None: class Query[T](runtime.Query[T]): - _connection_factory = staticmethod({package_name}_connection) + _connection_factory = staticmethod({module_name}_connection) {"\n\n\n".join(query_classes)} @@ -580,13 +580,13 @@ class Query[T](runtime.Query[T]): {"\n".join(query_overloads)} @overload -def {sql_fn_name}(stmt: str) -> Query: ... +def {sql_fn_name}(sql: str) -> Query: ... -def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query: - if stmt in _QUERIES: - return _QUERIES[stmt]() - msg = f"Unknown statement: {{stmt!r}}" +def {sql_fn_name}(sql: str, row_type: str | None = None) -> Query: + if sql in _QUERIES: + return _QUERIES[sql]() + msg = f"Unknown statement: {{sql!r}}" raise KeyError(msg) """.strip() # noqa: E501 @@ -594,11 +594,11 @@ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query: def render_enum_class( enum: Enum, - package_name: str, + module_name: str, to_pascal_fn: Callable[[str], str], to_snake_fn: Callable[[str], str], ) -> str: - class_name = to_pascal_fn(f"{package_name}_{to_snake_fn(enum.name)}") + class_name = to_pascal_fn(f"{module_name}_{to_snake_fn(enum.name)}") members = [] seen_names: dict[str, int] = {} @@ -654,7 +654,7 @@ def deduplicate_params(params: list[ParamSpec]) -> list[ParamSpec]: def render_query_class( query_name: str, - stmt: str, + sql: str, query_params: list[ParamSpec], result: str, columns_num: int, @@ -729,7 +729,7 @@ async def execute({", ".join(query_fn_params)}) -> None: class {query_name}(Query[{result}]): # See: {", ".join(locations)} - _stmt = psycopg.sql.SQL({stmt!r}) + _stmt = psycopg.sql.SQL({sql!r}) _row_factory = staticmethod({row_factory}) {indent_block(methods, " ")} @@ -738,7 +738,7 @@ class {query_name}(Query[{result}]): def render_query_overload( - sql_fn_name: str, query_name: str, stmt: str, row_type: str | None + sql_fn_name: str, query_name: str, sql: str, row_type: str | None ) -> str: result_arg = "" if row_type: @@ -747,25 +747,25 @@ def render_query_overload( return f""" @overload -def {sql_fn_name}(stmt: Literal[{stmt!r}]{result_arg}) -> {query_name}: ... +def {sql_fn_name}(sql: Literal[{sql!r}]{result_arg}) -> {query_name}: ... """.strip() -def render_query_dict_entry(query_name: str, stmt: str) -> str: - return f"{stmt!r}: {query_name}" +def render_query_dict_entry(query_name: str, sql: str) -> str: + return f"{sql!r}: {query_name}" @dataclass(kw_only=True, frozen=True) class CodeQuery: - stmt: str + sql: str row_type: str | None file: Path lineno: int @property def name(self) -> str: - md5_hash = hashlib.md5(self.stmt.encode(), usedforsecurity=False).hexdigest() + md5_hash = hashlib.md5(self.sql.encode(), usedforsecurity=False).hexdigest() return f"Query_{md5_hash}{'_' + self.row_type if self.row_type else ''}" @property @@ -776,17 +776,17 @@ def location(self) -> str: @dataclass(kw_only=True, frozen=True) class SQLEntity: resolver: TypeResolver - set_name: str | None + explicit_name: str | None table_name: str | None columns: tuple[Column, ...] @property def name(self) -> str: - if self.set_name: - return self.set_name + if self.explicit_name: + return self.explicit_name if self.table_name: return self.resolver.to_pascal_fn( - f"{self.resolver.package_name}_{inflection.singularize(self.table_name)}" + f"{self.resolver.module_name}_{inflection.singularize(self.table_name)}" ) hash_base = repr(self.column_specs) md5_hash = hashlib.md5(hash_base.encode(), usedforsecurity=False).hexdigest() @@ -797,7 +797,7 @@ def column_specs(self) -> tuple[ColumnSpec, ...]: return tuple(self.resolver.column_spec(c) for c in self.columns) -def map_entities( +def build_entities( queries_from_sqlc: tuple[Query, ...], used_schemas: tuple[str, ...], queries_from_code: list[CodeQuery], @@ -808,7 +808,7 @@ def map_entities( table_entities = [ SQLEntity( resolver=resolver, - set_name=None, + explicit_name=None, table_name=t.rel.name, columns=t.columns, ) @@ -828,7 +828,7 @@ def map_entities( query_result_entities = { q.name: SQLEntity( resolver=resolver, - set_name=row_types[q.name], + explicit_name=row_types[q.name], table_name=None, columns=q.columns, ) @@ -845,17 +845,17 @@ def map_entities( key=lambda e: (e.table_name is None, e.table_name or ""), ) - result_types = {} + query_result_types = {} for q in queries_from_sqlc: if len(q.columns) == 0: - result_types[q.name] = "None" + query_result_types[q.name] = "None" elif len(q.columns) == 1: - result_types[q.name] = resolver.column_spec(q.columns[0]).py_type + query_result_types[q.name] = resolver.column_spec(q.columns[0]).py_type else: column_specs = query_result_entities[q.name].column_specs - result_types[q.name] = unique_entities[column_specs].name + query_result_types[q.name] = unique_entities[column_specs].name - return ordered_entities, result_types + return ordered_entities, query_result_types def find_fn_calls( @@ -882,11 +882,11 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]: for file, lineno, node in find_fn_calls(src_path, sql_fn_name): relative_path = file.relative_to(src_path) - stmt_arg = node.args[0] + sql_arg = node.args[0] if ( len(node.args) != 1 - or not isinstance(stmt_arg, ast.Constant) - or not isinstance(stmt_arg.value, str) + or not isinstance(sql_arg, ast.Constant) + or not isinstance(sql_arg.value, str) ): msg = ( f"Invalid positional arguments for {sql_fn_name} " @@ -895,7 +895,7 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]: ) raise TypeError(msg) - stmt = stmt_arg.value + sql = sql_arg.value row_type = None for kw in node.keywords: @@ -912,18 +912,18 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]: break yield CodeQuery( - stmt=stmt, + sql=sql, row_type=row_type, file=relative_path, lineno=lineno, ) -def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None: - first_by_stmt: dict[str, CodeQuery] = {} +def validate_sql_has_single_row_type(queries: list[CodeQuery]) -> None: + first_by_sql: dict[str, CodeQuery] = {} for query in queries: - if query.stmt in first_by_stmt: - first = first_by_stmt[query.stmt] + if query.sql in first_by_sql: + first = first_by_sql[query.sql] if query.row_type != first.row_type: msg = ( f"row_type conflict: {first.location} has {first.row_type!r}," @@ -931,4 +931,4 @@ def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None: ) raise ValueError(msg) else: - first_by_stmt[query.stmt] = query + first_by_sql[query.sql] = query diff --git a/src/iron_sql/codegen/sqlc.py b/src/iron_sql/codegen/sqlc.py index 9add016..378895c 100644 --- a/src/iron_sql/codegen/sqlc.py +++ b/src/iron_sql/codegen/sqlc.py @@ -159,8 +159,8 @@ def run_sqlc( block_starts: list[tuple[int, str]] = [] blocks: list[str] = [] current_line = 1 - for name, stmt in queries: - block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};" + for name, sql in queries: + block = f"-- name: {name} :exec\n{preprocess_sql(sql)};" block_starts.append((current_line, name)) current_line += block.count("\n") + 2 blocks.append(block) @@ -214,6 +214,6 @@ def run_sqlc( ), block_starts -def preprocess_sql(stmt: str) -> str: - stmt = re.sub(r"@(\w+)\?", r"sqlc.narg('\1')", stmt) - return textwrap.dedent(stmt).strip() +def preprocess_sql(sql: str) -> str: + sql = re.sub(r"@(\w+)\?", r"sqlc.narg('\1')", sql) + return textwrap.dedent(sql).strip() diff --git a/tests/conftest.py b/tests/conftest.py index 8f13f43..b9b4f58 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from psycopg import sql from testcontainers.postgres import PostgresContainer -from iron_sql.codegen import generate_sql_package +from iron_sql.codegen import generate_sql_module from iron_sql.runtime import ConnectionPool # ============================================================================= @@ -146,7 +146,7 @@ def __init__( self.dsn = dsn self.test_name = test_name self.schema_path = schema_path - self.pkg_name = f"testapp_{test_name}.testdb" + self.module_full_name = f"testapp_{test_name}.testdb" self.src_path = root / "src" self.app_pkg = f"testapp_{test_name}" self.app_dir = self.src_path / self.app_pkg @@ -212,10 +212,10 @@ def generate_no_import( if str(self.src_path) not in sys.path: sys.path.insert(0, str(self.src_path)) - return generate_sql_package( + return generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=self.pkg_name, - dsn_import=f"{self.app_pkg}.config:DSN", + module_full_name=self.module_full_name, + dsn_expr=f"{self.app_pkg}.config:DSN", src_path=self.src_path, tempdir_path=self.src_path, type_overrides=type_overrides, @@ -234,9 +234,9 @@ def generate( ) importlib.invalidate_caches() - sys.modules.pop(self.pkg_name, None) + sys.modules.pop(self.module_full_name, None) - mod = importlib.import_module(self.pkg_name) + mod = importlib.import_module(self.module_full_name) self.generated_modules.append(mod) return mod diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 71aeef6..9f68311 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -7,7 +7,7 @@ import pytest -from iron_sql.codegen import generate_sql_package +from iron_sql.codegen import generate_sql_module from iron_sql.codegen.generator import ModuleExprRef from iron_sql.codegen.sqlc import run_sqlc from tests.conftest import ProjectBuilder @@ -27,7 +27,7 @@ def testdb_sql(q: str, **kwargs: Any) -> Any: ... test_project.generate_no_import() generated = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ).read_text() see_lines = [ @@ -238,7 +238,7 @@ def get_dsn() -> str: assert expr_ref.evaluate(expected_type=str) == test_project.dsn -def test_dsn_import_with_function_call(test_project: ProjectBuilder) -> None: +def test_dsn_expr_with_function_call(test_project: ProjectBuilder) -> None: (test_project.app_dir / "config.py").write_text( f""" class Config: @@ -257,22 +257,22 @@ def get_dsn(self) -> str: if str(test_project.src_path) not in sys.path: sys.path.insert(0, str(test_project.src_path)) - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=test_project.pkg_name, - dsn_import=f"{test_project.app_pkg}.config:CONFIG.get_dsn()", + module_full_name=test_project.module_full_name, + dsn_expr=f"{test_project.app_pkg}.config:CONFIG.get_dsn()", src_path=test_project.src_path, tempdir_path=test_project.src_path, ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text() assert "CONFIG.get_dsn()" in generated -def test_dsn_import_with_factory_call_generates_valid_python( +def test_dsn_expr_with_factory_call_generates_valid_python( test_project: ProjectBuilder, ) -> None: (test_project.app_dir / "config.py").write_text( @@ -289,16 +289,16 @@ def get_dsn() -> str: if str(test_project.src_path) not in sys.path: sys.path.insert(0, str(test_project.src_path)) - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=test_project.pkg_name, - dsn_import=f"{test_project.app_pkg}.config:get_dsn()", + module_full_name=test_project.module_full_name, + dsn_expr=f"{test_project.app_pkg}.config:get_dsn()", src_path=test_project.src_path, tempdir_path=test_project.src_path, ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text() assert f"from {test_project.app_pkg}.config import get_dsn" in generated @@ -306,7 +306,7 @@ def get_dsn() -> str: parse(generated) -def test_pool_options_import(test_project: ProjectBuilder) -> None: +def test_pool_options_expr(test_project: ProjectBuilder) -> None: config = ( f'DSN = "{test_project.dsn}"\nPOOL_OPTIONS = {{"min_size": 1, "max_size": 5}}\n' ) @@ -317,24 +317,24 @@ def test_pool_options_import(test_project: ProjectBuilder) -> None: if str(test_project.src_path) not in sys.path: sys.path.insert(0, str(test_project.src_path)) - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=test_project.pkg_name, - dsn_import=f"{test_project.app_pkg}.config:DSN", - pool_options_import=f"{test_project.app_pkg}.config:POOL_OPTIONS", + module_full_name=test_project.module_full_name, + dsn_expr=f"{test_project.app_pkg}.config:DSN", + pool_options_expr=f"{test_project.app_pkg}.config:POOL_OPTIONS", src_path=test_project.src_path, tempdir_path=test_project.src_path, ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text() assert f"from {test_project.app_pkg}.config import POOL_OPTIONS" in generated assert "pool_options=POOL_OPTIONS" in generated -def test_pool_options_import_factory_generates_valid_python( +def test_pool_options_expr_factory_generates_valid_python( test_project: ProjectBuilder, ) -> None: config = f"""DSN = "{test_project.dsn}" @@ -349,17 +349,17 @@ def get_pool_options() -> dict[str, object]: if str(test_project.src_path) not in sys.path: sys.path.insert(0, str(test_project.src_path)) - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=test_project.pkg_name, - dsn_import=f"{test_project.app_pkg}.config:DSN", - pool_options_import=f"{test_project.app_pkg}.config:get_pool_options()", + module_full_name=test_project.module_full_name, + dsn_expr=f"{test_project.app_pkg}.config:DSN", + pool_options_expr=f"{test_project.app_pkg}.config:get_pool_options()", src_path=test_project.src_path, tempdir_path=test_project.src_path, ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text() assert f"from {test_project.app_pkg}.config import get_pool_options" in generated @@ -367,7 +367,7 @@ def get_pool_options() -> dict[str, object]: parse(generated) -def test_pool_options_import_invalid_fails_during_generation( +def test_pool_options_expr_invalid_fails_during_generation( test_project: ProjectBuilder, ) -> None: config = f'DSN = "{test_project.dsn}"\n' @@ -379,27 +379,27 @@ def test_pool_options_import_invalid_fails_during_generation( sys.path.insert(0, str(test_project.src_path)) with pytest.raises(NameError, match="MISSING_POOL_OPTIONS"): - generate_sql_package( + generate_sql_module( schema_path=Path("schema.sql"), - package_full_name=test_project.pkg_name, - dsn_import=f"{test_project.app_pkg}.config:DSN", - pool_options_import=f"{test_project.app_pkg}.config:MISSING_POOL_OPTIONS", + module_full_name=test_project.module_full_name, + dsn_expr=f"{test_project.app_pkg}.config:DSN", + pool_options_expr=f"{test_project.app_pkg}.config:MISSING_POOL_OPTIONS", src_path=test_project.src_path, tempdir_path=test_project.src_path, ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) assert not generated_path.exists() -def test_pool_options_import_not_set(test_project: ProjectBuilder) -> None: +def test_pool_options_expr_not_set(test_project: ProjectBuilder) -> None: test_project.add_query("q", "SELECT 1 as value") test_project.generate_no_import() generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text() assert "**" not in generated diff --git a/tests/test_json_model_overrides.py b/tests/test_json_model_overrides.py index 222ee71..c1e4aaa 100644 --- a/tests/test_json_model_overrides.py +++ b/tests/test_json_model_overrides.py @@ -289,7 +289,7 @@ class Payload(BaseModel): ) generated_path = ( - test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + test_project.src_path / f"{test_project.module_full_name.replace('.', '/')}.py" ) generated = generated_path.read_text(encoding="utf-8")