From d756751ec191586e33e56e3726a3b264c7208ff2 Mon Sep 17 00:00:00 2001 From: Daniil Khristoliubov Date: Sat, 23 May 2026 15:51:56 +0300 Subject: [PATCH] feat: support schema-qualified table names for multi-schema databases --- README.md | 41 ++++ internal/codegen/constants/generator.go | 35 ++- internal/codegen/constants/render.go | 7 +- internal/codegen/constants/render_test.go | 73 ++++++ internal/codegen/crud/generator.go | 38 ++- internal/codegen/crud/generator_test.go | 69 +++++- internal/codegen/models/generator.go | 67 ++++-- internal/codegen/models/generator_pg_test.go | 218 ++++++++++++++++++ internal/codegen/orchestrator.go | 10 +- internal/codegen/orchestrator_test.go | 155 +++++++++++++ internal/codegen/sqlc/config.go | 7 +- internal/codegen/sqlc/config_test.go | 77 +++++++ internal/config/v2.go | 42 ++++ internal/config/v2_test.go | 32 +++ internal/sqlparser/catalog/catalog.go | 1 + internal/sqlparser/postgresql.go | 3 +- internal/sqlparser/typemap/postgresql.go | 48 +++- internal/sqlparser/typemap/postgresql_test.go | 47 ++++ internal/sqlparser/typemap/typemap.go | 1 + 19 files changed, 924 insertions(+), 47 deletions(-) create mode 100644 internal/codegen/constants/render_test.go create mode 100644 internal/codegen/sqlc/config_test.go diff --git a/README.md b/README.md index 4447e20..f2b9531 100644 --- a/README.md +++ b/README.md @@ -248,6 +248,47 @@ tables: find: {} # → auto adds WHERE deleted_at IS NULL ``` +### Schema-qualified table names + +For tables in non-default schemas, use dot notation in the table key. The migration files must include `CREATE SCHEMA` before creating objects in that schema: + +```sql +-- sql/migrations/001_init.sql +CREATE SCHEMA shop; + +CREATE TABLE users ( id UUID PRIMARY KEY, email TEXT NOT NULL ); +CREATE TABLE shop.orders ( id UUID PRIMARY KEY, total NUMERIC NOT NULL ); + +CREATE TYPE shop.order_status AS ENUM ('pending', 'shipped', 'delivered'); +``` + +```yaml +tables: + users: # public schema (default) + primary_column: id + crud: + methods: + get: {} + create: { returning: "*" } + + shop.orders: # "shop" schema + primary_column: id + crud: + methods: + get: {} + create: { returning: "*" } +``` + +When a schema prefix is specified: + +- **SQL** uses the schema-qualified name: `INSERT INTO shop.orders`, `SELECT * FROM shop.orders` +- **Go identifiers** are prefixed with the schema name: `CreateShopOrder`, `GetShopOrder` +- **File paths** are prefixed with the schema name: `sql/queries/shop_orders/shop_orders_gen.sql` +- **Go types** (models, enums) are prefixed to match sqlc's naming convention: `shop.orders` → `ShopOrder`, `shop.order_status` → `ShopOrderStatus` +- **Cross-schema references** are resolved correctly — a column of type `shop.order_status` in any schema resolves to `ShopOrderStatus` + +This ensures no collisions when the same table name exists in different schemas. Tables in the default schema (`public` for PostgreSQL, `main` for SQLite) remain unprefixed. + ## Commands | Command | Description | diff --git a/internal/codegen/constants/generator.go b/internal/codegen/constants/generator.go index 47a0f40..bb7f7b4 100644 --- a/internal/codegen/constants/generator.go +++ b/internal/codegen/constants/generator.go @@ -6,6 +6,7 @@ import ( "go/format" "path/filepath" "sort" + "strings" "github.com/gobeam/stringy" "github.com/tkcrm/pgxgen/internal/config" @@ -55,6 +56,7 @@ func generateConstants( for _, tableName := range tableNames { tableConfig := schema.Tables[tableName] + parts := config.ParseTableKey(tableName) includeColumns := defaultIncludeColumns hasConstants := tableConfig.Constants != nil @@ -65,7 +67,8 @@ func generateConstants( continue } - outputDir := schema.ResolveOutputDir(tableName) + // Use Go name for directory resolution + outputDir := schema.ResolveOutputDir(parts.GoName()) if tableConfig.OutputDir != "" { outputDir = tableConfig.OutputDir } @@ -103,10 +106,16 @@ func generateConstants( } for _, entry := range entries { - tablePreffix := stringy.New(entry.tableName).CamelCase().UcFirst() + parts := config.ParseTableKey(entry.tableName) + goName := parts.GoName() + // Use Go name for identifier prefix (schema-prefixed when applicable) + tablePreffix := stringy.New(goName).CamelCase().UcFirst() + // Use schema-qualified name for constant value (SQL usage) + sqlName := parts.SQLTableName() params.Tables = append(params.Tables, ConstantsTableNamesParamsItem{ - NamePreffix: tablePreffix, - Name: entry.tableName, + NamePreffix: tablePreffix, + Name: sqlName, + BareTableName: goName, }) if entry.includeColumns { @@ -114,7 +123,7 @@ func generateConstants( for _, col := range columns { colPreffix := tablePreffix + stringy.New(col).CamelCase().UcFirst() params.ColumnNames = append(params.ColumnNames, ConstantsColumnNamesParamsItem{ - TableName: entry.tableName, + TableName: sqlName, NamePreffix: colPreffix, Name: col, }) @@ -144,9 +153,14 @@ func generateConstants( } func getTableColumns(cat *catalog.Catalog, tableName string) []string { + schemaName, bareName := splitSchemaTable(tableName) + for _, s := range cat.Schemas { + if schemaName != "" && s.Name != schemaName { + continue + } for _, t := range s.Tables { - if t.Name == tableName { + if t.Name == bareName { cols := make([]string, len(t.Columns)) for i, c := range t.Columns { cols[i] = c.Name @@ -157,3 +171,12 @@ func getTableColumns(cat *catalog.Catalog, tableName string) []string { } return nil } + +// splitSchemaTable splits "schema.table" into ("schema", "table"). +// For a bare "table" name it returns ("", "table"). +func splitSchemaTable(name string) (string, string) { + if idx := strings.LastIndex(name, "."); idx != -1 { + return name[:idx], name[idx+1:] + } + return "", name +} diff --git a/internal/codegen/constants/render.go b/internal/codegen/constants/render.go index 5157d03..cb40db9 100644 --- a/internal/codegen/constants/render.go +++ b/internal/codegen/constants/render.go @@ -10,8 +10,9 @@ import ( ) type ConstantsTableNamesParamsItem struct { - NamePreffix string - Name string + NamePreffix string + Name string + BareTableName string } type ConstantsColumnNamesParamsItem struct { @@ -94,7 +95,7 @@ func (s ColumnNames) Strings() []string { content.WriteString(")\n\n") for _, tableName := range p.Tables { - fmt.Fprintf(&content, "func %sColumnNames() ColumnNames {\n", utils.ToPascalCase(tableName.Name)) + fmt.Fprintf(&content, "func %sColumnNames() ColumnNames {\n", utils.ToPascalCase(tableName.BareTableName)) content.WriteString("return ColumnNames{\n") for _, item := range p.GetColumnsForTable(tableName.Name) { fmt.Fprintf(&content, "ColumnName%s,\n", item.NamePreffix) diff --git a/internal/codegen/constants/render_test.go b/internal/codegen/constants/render_test.go new file mode 100644 index 0000000..ed73e6b --- /dev/null +++ b/internal/codegen/constants/render_test.go @@ -0,0 +1,73 @@ +package constants + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderConstants_SameNamedTablesInDifferentSchemas(t *testing.T) { + params := ConstantsParams{ + Package: "store", + Tables: []ConstantsTableNamesParamsItem{ + {NamePreffix: "Orders", Name: "orders", BareTableName: "orders"}, + {NamePreffix: "ShopOrders", Name: "shop.orders", BareTableName: "shop_orders"}, + }, + ColumnNames: []ConstantsColumnNamesParamsItem{ + {TableName: "orders", NamePreffix: "OrdersId", Name: "id"}, + {TableName: "orders", NamePreffix: "OrdersTotal", Name: "total"}, + {TableName: "shop.orders", NamePreffix: "ShopOrdersId", Name: "id"}, + {TableName: "shop.orders", NamePreffix: "ShopOrdersAmount", Name: "amount"}, + }, + } + + var buf bytes.Buffer + err := RenderConstants(params, &buf) + require.NoError(t, err) + output := buf.String() + + // Distinct table name constants + assert.Contains(t, output, `TableNameOrders TableName = "orders"`) + assert.Contains(t, output, `TableNameShopOrders TableName = "shop.orders"`) + + // Distinct column name constants — no collision + assert.Contains(t, output, `ColumnNameOrdersId ColumnName = "id"`) + assert.Contains(t, output, `ColumnNameOrdersTotal ColumnName = "total"`) + assert.Contains(t, output, `ColumnNameShopOrdersId ColumnName = "id"`) + assert.Contains(t, output, `ColumnNameShopOrdersAmount ColumnName = "amount"`) + + // Distinct ColumnNames() functions + assert.Contains(t, output, "func OrdersColumnNames() ColumnNames {") + assert.Contains(t, output, "func ShopOrdersColumnNames() ColumnNames {") + + // Each function returns only its own columns + // OrdersColumnNames should contain OrdersId, OrdersTotal + // ShopOrdersColumnNames should contain ShopOrdersId, ShopOrdersAmount + assert.Contains(t, output, "ColumnNameOrdersId,\nColumnNameOrdersTotal,") + assert.Contains(t, output, "ColumnNameShopOrdersId,\nColumnNameShopOrdersAmount,") +} + +func TestRenderConstants_SingleSchemaUnchanged(t *testing.T) { + params := ConstantsParams{ + Package: "store", + Tables: []ConstantsTableNamesParamsItem{ + {NamePreffix: "Users", Name: "users", BareTableName: "users"}, + }, + ColumnNames: []ConstantsColumnNamesParamsItem{ + {TableName: "users", NamePreffix: "UsersId", Name: "id"}, + {TableName: "users", NamePreffix: "UsersName", Name: "name"}, + }, + } + + var buf bytes.Buffer + err := RenderConstants(params, &buf) + require.NoError(t, err) + output := buf.String() + + // Plain table — no schema prefix + assert.Contains(t, output, `TableNameUsers TableName = "users"`) + assert.Contains(t, output, `ColumnNameUsersId ColumnName = "id"`) + assert.Contains(t, output, "func UsersColumnNames() ColumnNames {") +} diff --git a/internal/codegen/crud/generator.go b/internal/codegen/crud/generator.go index a442382..ccfaf1d 100644 --- a/internal/codegen/crud/generator.go +++ b/internal/codegen/crud/generator.go @@ -46,8 +46,11 @@ func New(eng engine.Engine) *Generator { } // GenerateTable generates all CRUD SQL for a single table. +// sqlTableName is the schema-qualified name used in SQL (e.g. "shop.orders"), +// bareTableName is the plain name used for Go identifiers (e.g. "orders"). func (g *Generator) GenerateTable( - tableName string, + sqlTableName string, + bareTableName string, tableConfig config.TableConfig, defaultCrud *config.CrudDefaultsConfig, columns []string, @@ -72,15 +75,15 @@ func (g *Generator) GenerateTable( } data, templateName, err := g.buildTemplateData( - tableName, methodName, methodCfg, tableConfig, defaultCrud, columns, + sqlTableName, bareTableName, methodName, methodCfg, tableConfig, defaultCrud, columns, ) if err != nil { - return nil, fmt.Errorf("build template data for %s.%s: %w", tableName, methodName, err) + return nil, fmt.Errorf("build template data for %s.%s: %w", sqlTableName, methodName, err) } rendered, err := g.renderTemplate(templateName, data) if err != nil { - return nil, fmt.Errorf("render template %s for %s.%s: %w", templateName, tableName, methodName, err) + return nil, fmt.Errorf("render template %s for %s.%s: %w", templateName, sqlTableName, methodName, err) } buf.Write(rendered) @@ -96,21 +99,21 @@ func (g *Generator) GenerateTable( } func (g *Generator) buildTemplateData( - tableName, methodName string, + sqlTableName, bareTableName, methodName string, methodCfg *config.MethodConfig, tableConfig config.TableConfig, defaultCrud *config.CrudDefaultsConfig, allColumns []string, ) (*TemplateData, string, error) { data := &TemplateData{ - TableName: tableName, + TableName: sqlTableName, PrimaryColumn: tableConfig.PrimaryColumn, } - // Resolve method name + // Resolve method name using bare table name (no schema prefix in Go identifiers) data.MethodName = methodCfg.Name if data.MethodName == "" { - data.MethodName = g.defaultMethodName(methodName, tableName, defaultCrud) + data.MethodName = g.defaultMethodName(methodName, bareTableName, defaultCrud) } // Soft delete @@ -400,10 +403,18 @@ func cloneWhere(w map[string]config.WhereParamConfig) map[string]config.WherePar } // GetTableColumns extracts column names from a catalog table. +// tableName can be a bare name ("users") or schema-qualified ("shop.users"). +// When schema-qualified, it matches only the specific schema; when bare, +// it searches all schemas (backward-compatible behavior). func GetTableColumns(cat *catalog.Catalog, tableName string) []string { + schemaName, bareName := splitSchemaTable(tableName) + for _, schema := range cat.Schemas { + if schemaName != "" && schema.Name != schemaName { + continue + } for _, table := range schema.Tables { - if table.Name == tableName { + if table.Name == bareName { columns := make([]string, len(table.Columns)) for i, col := range table.Columns { columns[i] = col.Name @@ -414,3 +425,12 @@ func GetTableColumns(cat *catalog.Catalog, tableName string) []string { } return nil } + +// splitSchemaTable splits "schema.table" into ("schema", "table"). +// For a bare "table" name it returns ("", "table"). +func splitSchemaTable(name string) (string, string) { + if idx := strings.LastIndex(name, "."); idx != -1 { + return name[:idx], name[idx+1:] + } + return "", name +} diff --git a/internal/codegen/crud/generator_test.go b/internal/codegen/crud/generator_test.go index ff46b4a..133c8f2 100644 --- a/internal/codegen/crud/generator_test.go +++ b/internal/codegen/crud/generator_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tkcrm/pgxgen/internal/config" "github.com/tkcrm/pgxgen/internal/engine" + "github.com/tkcrm/pgxgen/internal/sqlparser/catalog" "gopkg.in/yaml.v3" ) @@ -23,7 +24,7 @@ func newTestGenerator(t *testing.T, engineName string) *Generator { func generateTable(t *testing.T, engineName, tableName string, tableConfig config.TableConfig, defaultCrud *config.CrudDefaultsConfig, columns []string) string { t.Helper() gen := newTestGenerator(t, engineName) - result, err := gen.GenerateTable(tableName, tableConfig, defaultCrud, columns) + result, err := gen.GenerateTable(tableName, tableName, tableConfig, defaultCrud, columns) require.NoError(t, err) return string(result) } @@ -104,7 +105,7 @@ func TestGenerateTable_SingleTrailingNewline(t *testing.T) { } gen := newTestGenerator(t, "postgresql") - data, err := gen.GenerateTable("users", tableConfig, nil, testColumns) + data, err := gen.GenerateTable("users", "users", tableConfig, nil, testColumns) require.NoError(t, err) require.NotEmpty(t, data) @@ -124,7 +125,7 @@ func TestBatchCreate_UnsupportedSQLite(t *testing.T) { } gen := newTestGenerator(t, "sqlite") - _, err := gen.GenerateTable("users", tableConfig, nil, testColumns) + _, err := gen.GenerateTable("users", "users", tableConfig, nil, testColumns) require.Error(t, err) assert.Contains(t, err.Error(), "batch_create is not supported for SQLite") } @@ -190,6 +191,68 @@ func TestBatchCreate_CustomName(t *testing.T) { assert.Contains(t, result, "-- name: InsertBulkUsers :copyfrom") } +func TestSchemaQualifiedTableName(t *testing.T) { + tableConfig := config.TableConfig{ + PrimaryColumn: "id", + Crud: &config.TableCrudConfig{ + Methods: map[string]*config.MethodConfig{ + "create": {Returning: "*"}, + "get": {}, + }, + }, + } + + gen := newTestGenerator(t, "postgresql") + // sqlTableName is schema-qualified, goName is schema-prefixed for Go identifiers + result, err := gen.GenerateTable("shop.orders", "shop_orders", tableConfig, nil, testColumns) + require.NoError(t, err) + sql := string(result) + + // SQL should use schema-qualified table name + assert.Contains(t, sql, "INSERT INTO shop.orders") + assert.Contains(t, sql, "SELECT * FROM shop.orders") + + // Go method names should use schema-prefixed name + assert.Contains(t, sql, "-- name: CreateShopOrder :one") + assert.Contains(t, sql, "-- name: GetShopOrder :one") +} + +func TestSchemaQualifiedGetTableColumns(t *testing.T) { + cat := &catalog.Catalog{ + DefaultSchema: "public", + Schemas: []*catalog.Schema{ + { + Name: "public", + Tables: []*catalog.Table{ + {Name: "users", Schema: "public", Columns: []*catalog.Column{ + {Name: "id"}, {Name: "name"}, + }}, + }, + }, + { + Name: "shop", + Tables: []*catalog.Table{ + {Name: "orders", Schema: "shop", Columns: []*catalog.Column{ + {Name: "id"}, {Name: "total"}, {Name: "user_id"}, + }}, + }, + }, + }, + } + + // Bare name — searches all schemas, finds first match + cols := GetTableColumns(cat, "users") + assert.Equal(t, []string{"id", "name"}, cols) + + // Schema-qualified — matches only the specific schema + cols = GetTableColumns(cat, "shop.orders") + assert.Equal(t, []string{"id", "total", "user_id"}, cols) + + // Schema-qualified — wrong schema returns nil + cols = GetTableColumns(cat, "public.orders") + assert.Nil(t, cols) +} + func TestBatchCreate_ExcludeTableName(t *testing.T) { tableConfig := config.TableConfig{ Crud: &config.TableCrudConfig{ diff --git a/internal/codegen/models/generator.go b/internal/codegen/models/generator.go index 97fd268..f457fb0 100644 --- a/internal/codegen/models/generator.go +++ b/internal/codegen/models/generator.go @@ -111,6 +111,7 @@ func GenerateFromV2Config( mapOpts := typemap.Options{ SqlPackage: sqlPackage, EmitPointersForNull: emitPointers, + DefaultSchema: cat.DefaultSchema, } var sqlcOverrides *config.SqlcOverridesConfig @@ -174,11 +175,24 @@ func renderModelsRaw( sqlcOverrides *config.SqlcOverridesConfig, sqlcDefaults *config.SqlcDefaultsConfig, ) ([]byte, error) { - // Collect all enums, tables, and views across schemas + // Collect all enums, tables, and views across schemas. + // For non-default schemas, prefix names with schema to match sqlc's naming + // (e.g. shop.shops → "shop_shops" → Go type "ShopShop"). + type enumEntry struct { + enum *catalog.Enum + goName string // schema-prefixed name for Go type generation + } + var allEnumEntries []enumEntry var allEnums []*catalog.Enum var allTables []*catalog.Table var allViews []*catalog.View for _, s := range cat.Schemas { + for _, e := range s.Enums { + allEnumEntries = append(allEnumEntries, enumEntry{ + enum: e, + goName: sqlcGoName(e.Name, s.Name, cat.DefaultSchema), + }) + } allEnums = append(allEnums, s.Enums...) allTables = append(allTables, s.Tables...) allViews = append(allViews, s.Views...) @@ -204,32 +218,40 @@ func renderModelsRaw( for _, name := range cfg.SkipEnums { skipEnums[name] = struct{}{} } + allEnumEntries = filterSlice(allEnumEntries, func(e enumEntry) bool { + _, skip := skipEnums[e.enum.Name] + return !skip + }) allEnums = filterSlice(allEnums, func(e *catalog.Enum) bool { _, skip := skipEnums[e.Name] return !skip }) } - sort.Slice(allEnums, func(i, j int) bool { - return allEnums[i].Name < allEnums[j].Name + sort.Slice(allEnumEntries, func(i, j int) bool { + return allEnumEntries[i].goName < allEnumEntries[j].goName }) sort.Slice(allTables, func(i, j int) bool { - return inflection.Singular(allTables[i].Name) < inflection.Singular(allTables[j].Name) + return inflection.Singular(sqlcGoName(allTables[i].Name, allTables[i].Schema, cat.DefaultSchema)) < + inflection.Singular(sqlcGoName(allTables[j].Name, allTables[j].Schema, cat.DefaultSchema)) }) sort.Slice(allViews, func(i, j int) bool { - return inflection.Singular(allViews[i].Name) < inflection.Singular(allViews[j].Name) + return inflection.Singular(sqlcGoName(allViews[i].Name, allViews[i].Schema, cat.DefaultSchema)) < + inflection.Singular(sqlcGoName(allViews[j].Name, allViews[j].Schema, cat.DefaultSchema)) }) // First pass: render body to collect which types are used var body bytes.Buffer - for _, enum := range allEnums { - renderEnum(&body, enum) + for _, entry := range allEnumEntries { + renderEnum(&body, entry.goName, entry.enum) } for _, table := range allTables { - renderStruct(&body, cfg, table.Name, table.Comment, table.Columns, mapper, allEnums, mapOpts, sqlcOverrides, sqlcDefaults) + goName := sqlcGoName(table.Name, table.Schema, cat.DefaultSchema) + renderStruct(&body, cfg, goName, table.Name, table.Comment, table.Columns, mapper, allEnums, mapOpts, sqlcOverrides, sqlcDefaults) } for _, view := range allViews { - renderStruct(&body, cfg, view.Name, view.Comment, view.Columns, mapper, allEnums, mapOpts, sqlcOverrides, sqlcDefaults) + goName := sqlcGoName(view.Name, view.Schema, cat.DefaultSchema) + renderStruct(&body, cfg, goName, view.Name, view.Comment, view.Columns, mapper, allEnums, mapOpts, sqlcOverrides, sqlcDefaults) } // Collect imports from the rendered body @@ -323,8 +345,8 @@ func extractImportPath(v interface{}) string { return "" } -func renderEnum(buf *bytes.Buffer, enum *catalog.Enum) { - typeName := toCamelCase(enum.Name) +func renderEnum(buf *bytes.Buffer, goName string, enum *catalog.Enum) { + typeName := toCamelCase(goName) fmt.Fprintf(buf, "type %s string\n\n", typeName) buf.WriteString("const (\n") @@ -353,7 +375,8 @@ func renderEnum(buf *bytes.Buffer, enum *catalog.Enum) { func renderStruct( buf *bytes.Buffer, cfg *config.ModelsConfig, - name string, + goName string, + tableName string, comment string, columns []*catalog.Column, mapper typemap.TypeMapper, @@ -362,8 +385,9 @@ func renderStruct( sqlcOverrides *config.SqlcOverridesConfig, sqlcDefaults *config.SqlcDefaultsConfig, ) { - // Singularize table/view name: todos → Todo, notes → Note - structName := toCamelCase(inflection.Singular(name)) + // Singularize and PascalCase the Go name (matches sqlc's naming convention). + // For non-default schemas, goName is "schema_table" (e.g. "shop_shops" → "ShopShop"). + structName := toCamelCase(inflection.Singular(goName)) if comment != "" { fmt.Fprintf(buf, "// %s %s\n", structName, comment) @@ -373,8 +397,8 @@ func renderStruct( for _, col := range columns { fieldName := toCamelCase(col.Name) - fieldType := resolveType(cfg, name, col, mapper, enums, mapOpts, sqlcOverrides) - tags := buildTags(cfg, name, col, sqlcOverrides, sqlcDefaults) + fieldType := resolveType(cfg, tableName, col, mapper, enums, mapOpts, sqlcOverrides) + tags := buildTags(cfg, tableName, col, sqlcOverrides, sqlcDefaults) tagStr := "" if len(tags) > 0 { @@ -551,6 +575,17 @@ func filterSlice[T any](s []T, keep func(T) bool) []T { return result } +// sqlcGoName returns the name that sqlc would use for a database object. +// For objects in the default schema (e.g. "public"), it returns the bare name. +// For objects in other schemas, it prepends the schema name with an underscore, +// matching sqlc's convention (e.g. schema "shop", table "shops" → "shop_shops"). +func sqlcGoName(name, schema, defaultSchema string) string { + if schema == "" || schema == defaultSchema { + return name + } + return schema + "_" + name +} + // toCamelCase converts snake_case to CamelCase with Go acronym handling. // Examples: user_id → UserID, http_url → HTTPURL, name → Name func toCamelCase(s string) string { diff --git a/internal/codegen/models/generator_pg_test.go b/internal/codegen/models/generator_pg_test.go index 95229f4..c89867f 100644 --- a/internal/codegen/models/generator_pg_test.go +++ b/internal/codegen/models/generator_pg_test.go @@ -146,6 +146,7 @@ func renderPg(t *testing.T, cfg *config.ModelsConfig, overrides *config.SqlcOver mapOpts := typemap.Options{ SqlPackage: sqlPkg, EmitPointersForNull: cfg.EmitPointersForNull, + DefaultSchema: cat.DefaultSchema, } code, err := renderModelsRaw(cfg, cat, mapper, mapOpts, overrides, defaults) require.NoError(t, err) @@ -313,6 +314,223 @@ func TestPg_ImportsContainOrb(t *testing.T) { assert.Contains(t, output, `"github.com/paulmach/orb"`) } +// --- sqlcGoName --- + +func TestSqlcGoName(t *testing.T) { + tests := []struct { + name, schema, defaultSchema, want string + }{ + {"users", "public", "public", "users"}, + {"users", "", "public", "users"}, + {"orders", "shop", "public", "shop_orders"}, + {"items", "main", "main", "items"}, + {"items", "other", "main", "other_items"}, + } + for _, tt := range tests { + got := sqlcGoName(tt.name, tt.schema, tt.defaultSchema) + if got != tt.want { + t.Errorf("sqlcGoName(%q, %q, %q) = %q, want %q", tt.name, tt.schema, tt.defaultSchema, got, tt.want) + } + } +} + +// --- Multi-schema model generation --- + +func TestPg_MultiSchemaEnumAndStruct(t *testing.T) { + cat := &catalog.Catalog{ + DefaultSchema: "public", + Schemas: []*catalog.Schema{ + { + Name: "public", + Enums: []*catalog.Enum{ + {Name: "user_status", Schema: "public", Values: []string{"active", "inactive"}}, + }, + Tables: []*catalog.Table{ + { + Name: "users", + Schema: "public", + Columns: []*catalog.Column{ + {Name: "id", Type: "uuid", NotNull: true}, + {Name: "status", Type: "user_status", NotNull: true}, + }, + }, + }, + }, + { + Name: "shop", + Enums: []*catalog.Enum{ + {Name: "order_status", Schema: "shop", Values: []string{"pending", "shipped"}}, + }, + Tables: []*catalog.Table{ + { + Name: "orders", + Schema: "shop", + Columns: []*catalog.Column{ + {Name: "id", Type: "uuid", NotNull: true}, + {Name: "status", Type: "shop.order_status", NotNull: true}, + }, + }, + }, + }, + }, + } + + mapper, err := typemap.NewTypeMapper("postgresql") + require.NoError(t, err) + mapOpts := typemap.Options{ + SqlPackage: "pgx/v5", + DefaultSchema: cat.DefaultSchema, + } + code, err := renderModelsRaw( + &config.ModelsConfig{PackageName: "models"}, + cat, mapper, mapOpts, nil, nil, + ) + require.NoError(t, err) + output := string(code) + + // Public schema enum — no prefix + assert.Contains(t, output, "type UserStatus string") + assert.Contains(t, output, `UserStatusActive UserStatus = "active"`) + + // Non-default schema enum — prefixed + assert.Contains(t, output, "type ShopOrderStatus string") + assert.Contains(t, output, `ShopOrderStatusPending ShopOrderStatus = "pending"`) + + // Public schema struct — no prefix + assert.Contains(t, output, "type User struct {") + assert.Contains(t, output, "\tStatus UserStatus") + + // Non-default schema struct — prefixed + assert.Contains(t, output, "type ShopOrder struct {") + assert.Contains(t, output, "\tStatus ShopOrderStatus") +} + +// --- Cross-schema enum reference --- + +func TestPg_CrossSchemaEnumReference(t *testing.T) { + // A table in "public" schema references an enum defined in "shop" schema + cat := &catalog.Catalog{ + DefaultSchema: "public", + Schemas: []*catalog.Schema{ + { + Name: "public", + Tables: []*catalog.Table{ + { + Name: "audit_logs", + Schema: "public", + Columns: []*catalog.Column{ + {Name: "id", Type: "uuid", NotNull: true}, + {Name: "order_status", Type: "shop.order_status", NotNull: true}, + {Name: "nullable_status", Type: "shop.order_status", NotNull: false}, + }, + }, + }, + }, + { + Name: "shop", + Enums: []*catalog.Enum{ + {Name: "order_status", Schema: "shop", Values: []string{"pending", "shipped", "delivered"}}, + }, + }, + }, + } + + mapper, err := typemap.NewTypeMapper("postgresql") + require.NoError(t, err) + mapOpts := typemap.Options{ + SqlPackage: "pgx/v5", + DefaultSchema: cat.DefaultSchema, + } + code, err := renderModelsRaw( + &config.ModelsConfig{PackageName: "models"}, + cat, mapper, mapOpts, nil, nil, + ) + require.NoError(t, err) + output := string(code) + + // Enum type should be schema-prefixed + assert.Contains(t, output, "type ShopOrderStatus string") + + // NOT NULL column referencing cross-schema enum + assert.Contains(t, output, "\tOrderStatus ShopOrderStatus") + + // Nullable column referencing cross-schema enum + assert.Contains(t, output, "\tNullableStatus NullShopOrderStatus") +} + +// --- Same-named enums in different schemas --- + +func TestPg_SameNamedEnumsInDifferentSchemas(t *testing.T) { + cat := &catalog.Catalog{ + DefaultSchema: "public", + Schemas: []*catalog.Schema{ + { + Name: "public", + Enums: []*catalog.Enum{ + {Name: "status", Schema: "public", Values: []string{"active", "inactive"}}, + }, + Tables: []*catalog.Table{ + { + Name: "users", + Schema: "public", + Columns: []*catalog.Column{ + {Name: "id", Type: "uuid", NotNull: true}, + {Name: "status", Type: "status", NotNull: true}, + }, + }, + }, + }, + { + Name: "shop", + Enums: []*catalog.Enum{ + {Name: "status", Schema: "shop", Values: []string{"pending", "shipped"}}, + }, + Tables: []*catalog.Table{ + { + Name: "orders", + Schema: "shop", + Columns: []*catalog.Column{ + {Name: "id", Type: "uuid", NotNull: true}, + {Name: "status", Type: "shop.status", NotNull: true}, + }, + }, + }, + }, + }, + } + + mapper, err := typemap.NewTypeMapper("postgresql") + require.NoError(t, err) + mapOpts := typemap.Options{ + SqlPackage: "pgx/v5", + DefaultSchema: cat.DefaultSchema, + } + code, err := renderModelsRaw( + &config.ModelsConfig{PackageName: "models"}, + cat, mapper, mapOpts, nil, nil, + ) + require.NoError(t, err) + output := string(code) + + // Public schema enum — no prefix + assert.Contains(t, output, "type Status string") + assert.Contains(t, output, `StatusActive Status = "active"`) + assert.Contains(t, output, `StatusInactive Status = "inactive"`) + + // Shop schema enum — prefixed to avoid collision + assert.Contains(t, output, "type ShopStatus string") + assert.Contains(t, output, `ShopStatusPending ShopStatus = "pending"`) + assert.Contains(t, output, `ShopStatusShipped ShopStatus = "shipped"`) + + // Public table uses bare enum type + assert.Contains(t, output, "type User struct {") + assert.Contains(t, output, "\tStatus Status") + + // Shop table uses prefixed enum type + assert.Contains(t, output, "type ShopOrder struct {") + assert.Contains(t, output, "\tStatus ShopStatus") +} + // --- View model generation for PostgreSQL --- func TestPg_ViewStructGenerated(t *testing.T) { diff --git a/internal/codegen/orchestrator.go b/internal/codegen/orchestrator.go index 6ff1473..0d0e01b 100644 --- a/internal/codegen/orchestrator.go +++ b/internal/codegen/orchestrator.go @@ -199,6 +199,7 @@ func (o *Orchestrator) generateCrud(schema *config.SchemaConfig, cat *catalog.Ca for _, tableName := range tableNames { tableConfig := schema.Tables[tableName] + parts := config.ParseTableKey(tableName) columns := crud.GetTableColumns(cat, tableName) if columns == nil { @@ -210,7 +211,9 @@ func (o *Orchestrator) generateCrud(schema *config.SchemaConfig, cat *catalog.Ca defaultCrud = schema.Defaults.Crud } - data, err := gen.GenerateTable(tableName, tableConfig, defaultCrud, columns) + // Pass schema-qualified name for SQL generation, Go name for identifiers + goName := parts.GoName() + data, err := gen.GenerateTable(parts.SQLTableName(), goName, tableConfig, defaultCrud, columns) if err != nil { return nil, fmt.Errorf("generate table %s: %w", tableName, err) } @@ -219,7 +222,8 @@ func (o *Orchestrator) generateCrud(schema *config.SchemaConfig, cat *catalog.Ca continue } - queriesDir := schema.ResolveQueriesDir(tableName) + // Use Go name for directory and file paths + queriesDir := schema.ResolveQueriesDir(goName) if tableConfig.QueriesDir != "" { queriesDir = tableConfig.QueriesDir } @@ -227,7 +231,7 @@ func (o *Orchestrator) generateCrud(schema *config.SchemaConfig, cat *catalog.Ca return nil, fmt.Errorf("no queries_dir resolved for table %s", tableName) } - outputPath := filepath.Join(o.resolvePath(queriesDir), tableName+"_gen.sql") + outputPath := filepath.Join(o.resolvePath(queriesDir), goName+"_gen.sql") results = append(results, PrepareResult(outputPath, data)) } diff --git a/internal/codegen/orchestrator_test.go b/internal/codegen/orchestrator_test.go index 2bd7b84..be1a990 100644 --- a/internal/codegen/orchestrator_test.go +++ b/internal/codegen/orchestrator_test.go @@ -107,6 +107,161 @@ func TestGenerate_CrudWrittenBeforeSqlc(t *testing.T) { assert.NotEmpty(t, goFiles, "sqlc must have generated at least one .go file in %s", sqlcOutDir) } +// TestGenerate_SchemaQualifiedTableNames verifies that when table keys +// contain a schema prefix (e.g. "myschema.users"), CRUD SQL uses the +// schema-qualified name in SQL statements and schema-prefixed Go name for paths. +func TestGenerate_SchemaQualifiedTableNames(t *testing.T) { + tmp := t.TempDir() + migrationsDir := filepath.Join(tmp, "sql", "migrations") + require.NoError(t, os.MkdirAll(migrationsDir, 0o755)) + + schemaSQL := `CREATE SCHEMA myschema; +CREATE TABLE myschema.users ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + name TEXT NOT NULL +);` + require.NoError(t, os.WriteFile(filepath.Join(migrationsDir, "001_init.sql"), []byte(schemaSQL), 0o644)) + + configPath := filepath.Join(tmp, "pgxgen.yaml") + cfg := &config.V2Config{ + Version: "2", + Schemas: []config.SchemaConfig{ + { + Name: "main", + Engine: "postgresql", + SchemaDir: "sql/migrations", + Defaults: &config.DefaultsConfig{ + QueriesDirPrefix: "sql/queries", + OutputDirPrefix: "internal/store/repos", + }, + Tables: map[string]config.TableConfig{ + "myschema.users": { + PrimaryColumn: "id", + Crud: &config.TableCrudConfig{ + Methods: map[string]*config.MethodConfig{ + "get": {}, + "create": {Returning: "*"}, + }, + }, + }, + }, + }, + }, + } + + orch := NewOrchestrator(logger.New(), cfg, configPath) + _, err := orch.Generate(context.Background(), GenerateOpts{Targets: []string{"crud"}}) + require.NoError(t, err) + + // File should use schema-prefixed Go name for path + sqlPath := filepath.Join(tmp, "sql", "queries", "myschema_users", "myschema_users_gen.sql") + assert.FileExists(t, sqlPath, "CRUD SQL file should use schema-prefixed Go name for path") + + // Schema-qualified dot path should NOT exist + assert.NoFileExists(t, filepath.Join(tmp, "sql", "queries", "myschema.users", "myschema.users_gen.sql")) + // Bare name path should NOT exist either (schema prefix required) + assert.NoFileExists(t, filepath.Join(tmp, "sql", "queries", "users", "users_gen.sql")) + + // SQL content should use schema-qualified name + data, err := os.ReadFile(sqlPath) + require.NoError(t, err) + content := string(data) + assert.Contains(t, content, "myschema.users", "SQL should contain schema-qualified table name") + assert.Contains(t, content, "INSERT INTO myschema.users", "INSERT should use schema-qualified name") + assert.Contains(t, content, "SELECT * FROM myschema.users", "SELECT should use schema-qualified name") + + // Method names should use schema-prefixed Go name + assert.Contains(t, content, "CreateMyschemaUser", "method name should use schema-prefixed name") + assert.Contains(t, content, "GetMyschemaUser", "method name should use schema-prefixed name") +} + +// TestGenerate_SameNameTablesInDifferentSchemas verifies that two tables +// with the same bare name in different schemas produce distinct output files, +// method names, and SQL with correct schema qualification. +func TestGenerate_SameNameTablesInDifferentSchemas(t *testing.T) { + tmp := t.TempDir() + migrationsDir := filepath.Join(tmp, "sql", "migrations") + require.NoError(t, os.MkdirAll(migrationsDir, 0o755)) + + schemaSQL := `CREATE SCHEMA shop; +CREATE TABLE orders ( + id TEXT PRIMARY KEY, + total TEXT NOT NULL +); +CREATE TABLE shop.orders ( + id TEXT PRIMARY KEY, + amount TEXT NOT NULL +);` + require.NoError(t, os.WriteFile(filepath.Join(migrationsDir, "001_init.sql"), []byte(schemaSQL), 0o644)) + + configPath := filepath.Join(tmp, "pgxgen.yaml") + cfg := &config.V2Config{ + Version: "2", + Schemas: []config.SchemaConfig{ + { + Name: "main", + Engine: "postgresql", + SchemaDir: "sql/migrations", + Defaults: &config.DefaultsConfig{ + QueriesDirPrefix: "sql/queries", + OutputDirPrefix: "internal/store/repos", + }, + Tables: map[string]config.TableConfig{ + "orders": { + PrimaryColumn: "id", + Crud: &config.TableCrudConfig{ + Methods: map[string]*config.MethodConfig{ + "get": {}, + "create": {Returning: "*"}, + }, + }, + }, + "shop.orders": { + PrimaryColumn: "id", + Crud: &config.TableCrudConfig{ + Methods: map[string]*config.MethodConfig{ + "get": {}, + "create": {Returning: "*"}, + }, + }, + }, + }, + }, + }, + } + + orch := NewOrchestrator(logger.New(), cfg, configPath) + _, err := orch.Generate(context.Background(), GenerateOpts{Targets: []string{"crud"}}) + require.NoError(t, err) + + // Public schema table — bare name paths + publicPath := filepath.Join(tmp, "sql", "queries", "orders", "orders_gen.sql") + assert.FileExists(t, publicPath) + + // Shop schema table — schema-prefixed paths + shopPath := filepath.Join(tmp, "sql", "queries", "shop_orders", "shop_orders_gen.sql") + assert.FileExists(t, shopPath) + + // Verify public schema SQL + publicData, err := os.ReadFile(publicPath) + require.NoError(t, err) + publicSQL := string(publicData) + assert.Contains(t, publicSQL, "SELECT * FROM orders") + assert.Contains(t, publicSQL, "INSERT INTO orders") + assert.Contains(t, publicSQL, "-- name: CreateOrder :one") + assert.Contains(t, publicSQL, "-- name: GetOrder :one") + + // Verify shop schema SQL + shopData, err := os.ReadFile(shopPath) + require.NoError(t, err) + shopSQL := string(shopData) + assert.Contains(t, shopSQL, "SELECT * FROM shop.orders") + assert.Contains(t, shopSQL, "INSERT INTO shop.orders") + assert.Contains(t, shopSQL, "-- name: CreateShopOrder :one") + assert.Contains(t, shopSQL, "-- name: GetShopOrder :one") +} + // TestGenerate_DryRunSkipsWritesAndSqlc verifies that dry-run touches // nothing on disk and does not invoke sqlc (which has no preview mode). func TestGenerate_DryRunSkipsWritesAndSqlc(t *testing.T) { diff --git a/internal/codegen/sqlc/config.go b/internal/codegen/sqlc/config.go index b03eeb4..2eb7206 100644 --- a/internal/codegen/sqlc/config.go +++ b/internal/codegen/sqlc/config.go @@ -116,13 +116,16 @@ func BuildSqlcConfig(schema *config.SchemaConfig) *sqlcConfig { for _, tableName := range tableNames { tableConfig := schema.Tables[tableName] + parts := config.ParseTableKey(tableName) - queriesDir := schema.ResolveQueriesDir(tableName) + // Use Go name for directory resolution (schema-prefixed when applicable) + goName := parts.GoName() + queriesDir := schema.ResolveQueriesDir(goName) if tableConfig.QueriesDir != "" { queriesDir = tableConfig.QueriesDir } - outputDir := schema.ResolveOutputDir(tableName) + outputDir := schema.ResolveOutputDir(goName) if tableConfig.OutputDir != "" { outputDir = tableConfig.OutputDir } diff --git a/internal/codegen/sqlc/config_test.go b/internal/codegen/sqlc/config_test.go new file mode 100644 index 0000000..36972b9 --- /dev/null +++ b/internal/codegen/sqlc/config_test.go @@ -0,0 +1,77 @@ +package sqlc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tkcrm/pgxgen/internal/config" +) + +func TestBuildSqlcConfig_SameNamedTablesInDifferentSchemas(t *testing.T) { + schema := &config.SchemaConfig{ + Engine: "postgresql", + SchemaDir: "sql/migrations", + Defaults: &config.DefaultsConfig{ + QueriesDirPrefix: "sql/queries", + OutputDirPrefix: "internal/store/repos", + }, + Tables: map[string]config.TableConfig{ + "orders": { + PrimaryColumn: "id", + }, + "shop.orders": { + PrimaryColumn: "id", + }, + }, + } + + cfg := BuildSqlcConfig(schema) + require.Len(t, cfg.SQL, 2) + + // Find entries by queries path (sorted alphabetically: "orders" before "shop.orders") + var publicEntry, shopEntry sqlcSQLEntry + for _, entry := range cfg.SQL { + if entry.Queries == "../sql/queries/orders" { + publicEntry = entry + } + if entry.Queries == "../sql/queries/shop_orders" { + shopEntry = entry + } + } + + // Public schema table — bare name directories + assert.Equal(t, "../sql/queries/orders", publicEntry.Queries) + assert.Equal(t, "../internal/store/repos/orders", publicEntry.Gen.Go.Out) + + // Shop schema table — schema-prefixed directories + assert.Equal(t, "../sql/queries/shop_orders", shopEntry.Queries) + assert.Equal(t, "../internal/store/repos/shop_orders", shopEntry.Gen.Go.Out) + + // Both use the same schema dir + assert.Equal(t, "../sql/migrations", publicEntry.Schema) + assert.Equal(t, "../sql/migrations", shopEntry.Schema) +} + +func TestBuildSqlcConfig_PlainTableUnchanged(t *testing.T) { + schema := &config.SchemaConfig{ + Engine: "postgresql", + SchemaDir: "sql/migrations", + Defaults: &config.DefaultsConfig{ + QueriesDirPrefix: "sql/queries", + OutputDirPrefix: "internal/store/repos", + }, + Tables: map[string]config.TableConfig{ + "users": { + PrimaryColumn: "id", + }, + }, + } + + cfg := BuildSqlcConfig(schema) + require.Len(t, cfg.SQL, 1) + + entry := cfg.SQL[0] + assert.Equal(t, "../sql/queries/users", entry.Queries) + assert.Equal(t, "../internal/store/repos/users", entry.Gen.Go.Out) +} diff --git a/internal/config/v2.go b/internal/config/v2.go index 46b32ed..fb3e259 100644 --- a/internal/config/v2.go +++ b/internal/config/v2.go @@ -1,5 +1,7 @@ package config +import "strings" + // V2Config is the root configuration for pgxgen v2. type V2Config struct { Version string `yaml:"version" validate:"required,eq=2"` @@ -247,3 +249,43 @@ func (s *SchemaConfig) ResolveOutputDir(tableName string) string { func (s *SchemaConfig) IsPerTableMode() bool { return s.Defaults != nil && s.Defaults.QueriesDirPrefix != "" } + +// TableKeyParts holds the parsed components of a table key from pgxgen.yaml. +// A key like "shop.orders" is split into Schema="shop", Table="orders". +// A key like "orders" is split into Schema="", Table="orders". +type TableKeyParts struct { + Schema string // database schema name, empty for default/public + Table string // bare table name without schema prefix +} + +// SQLTableName returns the table name as it should appear in SQL statements. +// If a schema is specified, it returns "schema.table"; otherwise just "table". +func (p TableKeyParts) SQLTableName() string { + if p.Schema != "" { + return p.Schema + "." + p.Table + } + return p.Table +} + +// GoName returns the name used for Go identifiers, file paths, and directories. +// For schema-qualified keys, it prepends the schema name with an underscore +// (e.g. "shop_orders") to avoid collisions between same-named tables in +// different schemas. For plain table keys, it returns the bare table name. +func (p TableKeyParts) GoName() string { + if p.Schema != "" { + return p.Schema + "_" + p.Table + } + return p.Table +} + +// ParseTableKey splits a pgxgen.yaml table map key into schema and table parts. +// Supports "schema.table" (dot-separated) and plain "table" formats. +func ParseTableKey(key string) TableKeyParts { + if idx := strings.LastIndex(key, "."); idx != -1 { + return TableKeyParts{ + Schema: key[:idx], + Table: key[idx+1:], + } + } + return TableKeyParts{Table: key} +} diff --git a/internal/config/v2_test.go b/internal/config/v2_test.go index 320201d..7d85076 100644 --- a/internal/config/v2_test.go +++ b/internal/config/v2_test.go @@ -82,6 +82,38 @@ func TestResolveOutputDir(t *testing.T) { } } +func TestParseTableKey(t *testing.T) { + tests := []struct { + key string + wantSchema string + wantTable string + wantSQL string + wantGoName string + }{ + {key: "users", wantSchema: "", wantTable: "users", wantSQL: "users", wantGoName: "users"}, + {key: "shop.orders", wantSchema: "shop", wantTable: "orders", wantSQL: "shop.orders", wantGoName: "shop_orders"}, + {key: "my_schema.my_table", wantSchema: "my_schema", wantTable: "my_table", wantSQL: "my_schema.my_table", wantGoName: "my_schema_my_table"}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + parts := ParseTableKey(tt.key) + if parts.Schema != tt.wantSchema { + t.Errorf("ParseTableKey(%q).Schema = %q, want %q", tt.key, parts.Schema, tt.wantSchema) + } + if parts.Table != tt.wantTable { + t.Errorf("ParseTableKey(%q).Table = %q, want %q", tt.key, parts.Table, tt.wantTable) + } + if parts.SQLTableName() != tt.wantSQL { + t.Errorf("ParseTableKey(%q).SQLTableName() = %q, want %q", tt.key, parts.SQLTableName(), tt.wantSQL) + } + if parts.GoName() != tt.wantGoName { + t.Errorf("ParseTableKey(%q).GoName() = %q, want %q", tt.key, parts.GoName(), tt.wantGoName) + } + }) + } +} + func TestResolveQueriesDir(t *testing.T) { tests := []struct { name string diff --git a/internal/sqlparser/catalog/catalog.go b/internal/sqlparser/catalog/catalog.go index 15bc7f2..160dd4d 100644 --- a/internal/sqlparser/catalog/catalog.go +++ b/internal/sqlparser/catalog/catalog.go @@ -93,5 +93,6 @@ type CheckConstraint struct { // Enum represents a database enum type. type Enum struct { Name string + Schema string Values []string } diff --git a/internal/sqlparser/postgresql.go b/internal/sqlparser/postgresql.go index 642f2a6..f73cd74 100644 --- a/internal/sqlparser/postgresql.go +++ b/internal/sqlparser/postgresql.go @@ -275,7 +275,8 @@ func (p *postgresParser) handleCreateEnum(cat *catalog.Catalog, n *pg.CreateEnum schema := p.getOrCreateSchema(cat, schemaName) enum := &catalog.Enum{ - Name: enumName, + Name: enumName, + Schema: schemaName, } for _, val := range n.Vals { if s, ok := val.Node.(*pg.Node_String_); ok { diff --git a/internal/sqlparser/typemap/postgresql.go b/internal/sqlparser/typemap/postgresql.go index eac8562..ec40ca7 100644 --- a/internal/sqlparser/typemap/postgresql.go +++ b/internal/sqlparser/typemap/postgresql.go @@ -379,15 +379,55 @@ func (m *postgresMapper) GoType(col *catalog.Column, enums []*catalog.Enum, opts return "interface{}" default: - // Check for enum types + // Check for enum types. + // Column type may be schema-qualified (e.g. "shop.order_status"). + // When schema-qualified, prefer an exact schema+name match to avoid + // collisions between same-named enums in different schemas. + defaultSchema := opts.DefaultSchema + if defaultSchema == "" { + defaultSchema = "public" + } + + colSchema := "" + bareColumnType := columnType + if idx := strings.LastIndex(columnType, "."); idx != -1 { + colSchema = columnType[:idx] + bareColumnType = columnType[idx+1:] + } + + // First pass: if column type is schema-qualified, find exact schema match + if colSchema != "" { + for _, enum := range enums { + if strings.EqualFold(enum.Name, bareColumnType) && strings.EqualFold(enum.Schema, colSchema) { + goName := enumGoName(enum, defaultSchema) + if notNull { + return structName(goName) + } + return "Null" + structName(goName) + } + } + } + + // Second pass: match by bare name (backward-compatible for non-qualified types) for _, enum := range enums { - if strings.EqualFold(enum.Name, columnType) { + if strings.EqualFold(enum.Name, columnType) || strings.EqualFold(enum.Name, bareColumnType) { + goName := enumGoName(enum, defaultSchema) if notNull { - return structName(enum.Name) + return structName(goName) } - return "Null" + structName(enum.Name) + return "Null" + structName(goName) } } return "interface{}" } } + +// enumGoName returns the name to use for Go type generation. +// For enums in non-default schemas, it prepends the schema name to match +// sqlc's convention (e.g. schema "subscription", enum "order_status" → "subscription_order_status"). +func enumGoName(enum *catalog.Enum, defaultSchema string) string { + if enum.Schema != "" && enum.Schema != defaultSchema { + return enum.Schema + "_" + enum.Name + } + return enum.Name +} diff --git a/internal/sqlparser/typemap/postgresql_test.go b/internal/sqlparser/typemap/postgresql_test.go index 6c9d463..aa0fa5b 100644 --- a/internal/sqlparser/typemap/postgresql_test.go +++ b/internal/sqlparser/typemap/postgresql_test.go @@ -259,6 +259,53 @@ func TestPostgresEnumType(t *testing.T) { } } +func TestPostgresEnumType_NonDefaultSchema(t *testing.T) { + m := &postgresMapper{} + enums := []*catalog.Enum{ + {Name: "order_status", Schema: "shop", Values: []string{"pending", "shipped"}}, + } + opts := Options{DefaultSchema: "public"} + + // NOT NULL enum in non-default schema → prefixed Go type + got := m.GoType(col("order_status", true), enums, opts) + if got != "ShopOrderStatus" { + t.Errorf("non-default schema enum NOT NULL = %q, want ShopOrderStatus", got) + } + + // NULL enum in non-default schema → Null-prefixed + got = m.GoType(col("order_status", false), enums, opts) + if got != "NullShopOrderStatus" { + t.Errorf("non-default schema enum NULL = %q, want NullShopOrderStatus", got) + } +} + +func TestPostgresEnumType_DefaultSchemaNotPrefixed(t *testing.T) { + m := &postgresMapper{} + enums := []*catalog.Enum{ + {Name: "user_status", Schema: "public", Values: []string{"active", "inactive"}}, + } + opts := Options{DefaultSchema: "public"} + + got := m.GoType(col("user_status", true), enums, opts) + if got != "UserStatus" { + t.Errorf("default schema enum = %q, want UserStatus (no prefix)", got) + } +} + +func TestPostgresEnumType_SchemaQualifiedColumnType(t *testing.T) { + m := &postgresMapper{} + enums := []*catalog.Enum{ + {Name: "order_status", Schema: "shop", Values: []string{"pending", "shipped"}}, + } + opts := Options{DefaultSchema: "public"} + + // Column type is schema-qualified (e.g. "shop.order_status") + got := m.GoType(col("shop.order_status", true), enums, opts) + if got != "ShopOrderStatus" { + t.Errorf("schema-qualified column type = %q, want ShopOrderStatus", got) + } +} + func TestPostgresNullablePointers(t *testing.T) { m := &postgresMapper{} opts := Options{SqlPackage: "pgx/v5", EmitPointersForNull: true} diff --git a/internal/sqlparser/typemap/typemap.go b/internal/sqlparser/typemap/typemap.go index 014f2ad..6515e8c 100644 --- a/internal/sqlparser/typemap/typemap.go +++ b/internal/sqlparser/typemap/typemap.go @@ -11,6 +11,7 @@ import ( type Options struct { SqlPackage string // "pgx/v5", "pgx/v4", "database/sql" EmitPointersForNull bool + DefaultSchema string // default database schema (e.g. "public" for PostgreSQL) } // TypeMapper maps SQL column types to Go types.