Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,47 @@ tables:
find: {} # → auto adds WHERE deleted_at IS NULL
```

### Schema-qualified table names

For tables in non-default schemas, use dot notation in the table key. The migration files must include `CREATE SCHEMA` before creating objects in that schema:

```sql
-- sql/migrations/001_init.sql
CREATE SCHEMA shop;

CREATE TABLE users ( id UUID PRIMARY KEY, email TEXT NOT NULL );
CREATE TABLE shop.orders ( id UUID PRIMARY KEY, total NUMERIC NOT NULL );

CREATE TYPE shop.order_status AS ENUM ('pending', 'shipped', 'delivered');
```

```yaml
tables:
users: # public schema (default)
primary_column: id
crud:
methods:
get: {}
create: { returning: "*" }

shop.orders: # "shop" schema
primary_column: id
crud:
methods:
get: {}
create: { returning: "*" }
```

When a schema prefix is specified:

- **SQL** uses the schema-qualified name: `INSERT INTO shop.orders`, `SELECT * FROM shop.orders`
- **Go identifiers** are prefixed with the schema name: `CreateShopOrder`, `GetShopOrder`
- **File paths** are prefixed with the schema name: `sql/queries/shop_orders/shop_orders_gen.sql`
- **Go types** (models, enums) are prefixed to match sqlc's naming convention: `shop.orders` → `ShopOrder`, `shop.order_status` → `ShopOrderStatus`
- **Cross-schema references** are resolved correctly — a column of type `shop.order_status` in any schema resolves to `ShopOrderStatus`

This ensures no collisions when the same table name exists in different schemas. Tables in the default schema (`public` for PostgreSQL, `main` for SQLite) remain unprefixed.

## Commands

| Command | Description |
Expand Down
35 changes: 29 additions & 6 deletions internal/codegen/constants/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"go/format"
"path/filepath"
"sort"
"strings"

"github.com/gobeam/stringy"
"github.com/tkcrm/pgxgen/internal/config"
Expand Down Expand Up @@ -55,6 +56,7 @@ func generateConstants(

for _, tableName := range tableNames {
tableConfig := schema.Tables[tableName]
parts := config.ParseTableKey(tableName)

includeColumns := defaultIncludeColumns
hasConstants := tableConfig.Constants != nil
Expand All @@ -65,7 +67,8 @@ func generateConstants(
continue
}

outputDir := schema.ResolveOutputDir(tableName)
// Use Go name for directory resolution
outputDir := schema.ResolveOutputDir(parts.GoName())
if tableConfig.OutputDir != "" {
outputDir = tableConfig.OutputDir
}
Expand Down Expand Up @@ -103,18 +106,24 @@ func generateConstants(
}

for _, entry := range entries {
tablePreffix := stringy.New(entry.tableName).CamelCase().UcFirst()
parts := config.ParseTableKey(entry.tableName)
goName := parts.GoName()
// Use Go name for identifier prefix (schema-prefixed when applicable)
tablePreffix := stringy.New(goName).CamelCase().UcFirst()
// Use schema-qualified name for constant value (SQL usage)
sqlName := parts.SQLTableName()
params.Tables = append(params.Tables, ConstantsTableNamesParamsItem{
NamePreffix: tablePreffix,
Name: entry.tableName,
NamePreffix: tablePreffix,
Name: sqlName,
BareTableName: goName,
})

if entry.includeColumns {
columns := getTableColumns(cat, entry.tableName)
for _, col := range columns {
colPreffix := tablePreffix + stringy.New(col).CamelCase().UcFirst()
params.ColumnNames = append(params.ColumnNames, ConstantsColumnNamesParamsItem{
TableName: entry.tableName,
TableName: sqlName,
NamePreffix: colPreffix,
Name: col,
})
Expand Down Expand Up @@ -144,9 +153,14 @@ func generateConstants(
}

func getTableColumns(cat *catalog.Catalog, tableName string) []string {
schemaName, bareName := splitSchemaTable(tableName)

for _, s := range cat.Schemas {
if schemaName != "" && s.Name != schemaName {
continue
}
for _, t := range s.Tables {
if t.Name == tableName {
if t.Name == bareName {
cols := make([]string, len(t.Columns))
for i, c := range t.Columns {
cols[i] = c.Name
Expand All @@ -157,3 +171,12 @@ func getTableColumns(cat *catalog.Catalog, tableName string) []string {
}
return nil
}

// splitSchemaTable splits "schema.table" into ("schema", "table").
// For a bare "table" name it returns ("", "table").
func splitSchemaTable(name string) (string, string) {
if idx := strings.LastIndex(name, "."); idx != -1 {
return name[:idx], name[idx+1:]
}
return "", name
}
7 changes: 4 additions & 3 deletions internal/codegen/constants/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
)

type ConstantsTableNamesParamsItem struct {
NamePreffix string
Name string
NamePreffix string
Name string
BareTableName string
}

type ConstantsColumnNamesParamsItem struct {
Expand Down Expand Up @@ -94,7 +95,7 @@ func (s ColumnNames) Strings() []string {
content.WriteString(")\n\n")

for _, tableName := range p.Tables {
fmt.Fprintf(&content, "func %sColumnNames() ColumnNames {\n", utils.ToPascalCase(tableName.Name))
fmt.Fprintf(&content, "func %sColumnNames() ColumnNames {\n", utils.ToPascalCase(tableName.BareTableName))
content.WriteString("return ColumnNames{\n")
for _, item := range p.GetColumnsForTable(tableName.Name) {
fmt.Fprintf(&content, "ColumnName%s,\n", item.NamePreffix)
Expand Down
73 changes: 73 additions & 0 deletions internal/codegen/constants/render_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package constants

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRenderConstants_SameNamedTablesInDifferentSchemas(t *testing.T) {
params := ConstantsParams{
Package: "store",
Tables: []ConstantsTableNamesParamsItem{
{NamePreffix: "Orders", Name: "orders", BareTableName: "orders"},
{NamePreffix: "ShopOrders", Name: "shop.orders", BareTableName: "shop_orders"},
},
ColumnNames: []ConstantsColumnNamesParamsItem{
{TableName: "orders", NamePreffix: "OrdersId", Name: "id"},
{TableName: "orders", NamePreffix: "OrdersTotal", Name: "total"},
{TableName: "shop.orders", NamePreffix: "ShopOrdersId", Name: "id"},
{TableName: "shop.orders", NamePreffix: "ShopOrdersAmount", Name: "amount"},
},
}

var buf bytes.Buffer
err := RenderConstants(params, &buf)
require.NoError(t, err)
output := buf.String()

// Distinct table name constants
assert.Contains(t, output, `TableNameOrders TableName = "orders"`)
assert.Contains(t, output, `TableNameShopOrders TableName = "shop.orders"`)

// Distinct column name constants — no collision
assert.Contains(t, output, `ColumnNameOrdersId ColumnName = "id"`)
assert.Contains(t, output, `ColumnNameOrdersTotal ColumnName = "total"`)
assert.Contains(t, output, `ColumnNameShopOrdersId ColumnName = "id"`)
assert.Contains(t, output, `ColumnNameShopOrdersAmount ColumnName = "amount"`)

// Distinct ColumnNames() functions
assert.Contains(t, output, "func OrdersColumnNames() ColumnNames {")
assert.Contains(t, output, "func ShopOrdersColumnNames() ColumnNames {")

// Each function returns only its own columns
// OrdersColumnNames should contain OrdersId, OrdersTotal
// ShopOrdersColumnNames should contain ShopOrdersId, ShopOrdersAmount
assert.Contains(t, output, "ColumnNameOrdersId,\nColumnNameOrdersTotal,")
assert.Contains(t, output, "ColumnNameShopOrdersId,\nColumnNameShopOrdersAmount,")
}

func TestRenderConstants_SingleSchemaUnchanged(t *testing.T) {
params := ConstantsParams{
Package: "store",
Tables: []ConstantsTableNamesParamsItem{
{NamePreffix: "Users", Name: "users", BareTableName: "users"},
},
ColumnNames: []ConstantsColumnNamesParamsItem{
{TableName: "users", NamePreffix: "UsersId", Name: "id"},
{TableName: "users", NamePreffix: "UsersName", Name: "name"},
},
}

var buf bytes.Buffer
err := RenderConstants(params, &buf)
require.NoError(t, err)
output := buf.String()

// Plain table — no schema prefix
assert.Contains(t, output, `TableNameUsers TableName = "users"`)
assert.Contains(t, output, `ColumnNameUsersId ColumnName = "id"`)
assert.Contains(t, output, "func UsersColumnNames() ColumnNames {")
}
38 changes: 29 additions & 9 deletions internal/codegen/crud/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ func New(eng engine.Engine) *Generator {
}

// GenerateTable generates all CRUD SQL for a single table.
// sqlTableName is the schema-qualified name used in SQL (e.g. "shop.orders"),
// bareTableName is the plain name used for Go identifiers (e.g. "orders").
func (g *Generator) GenerateTable(
tableName string,
sqlTableName string,
bareTableName string,
tableConfig config.TableConfig,
defaultCrud *config.CrudDefaultsConfig,
columns []string,
Expand All @@ -72,15 +75,15 @@ func (g *Generator) GenerateTable(
}

data, templateName, err := g.buildTemplateData(
tableName, methodName, methodCfg, tableConfig, defaultCrud, columns,
sqlTableName, bareTableName, methodName, methodCfg, tableConfig, defaultCrud, columns,
)
if err != nil {
return nil, fmt.Errorf("build template data for %s.%s: %w", tableName, methodName, err)
return nil, fmt.Errorf("build template data for %s.%s: %w", sqlTableName, methodName, err)
}

rendered, err := g.renderTemplate(templateName, data)
if err != nil {
return nil, fmt.Errorf("render template %s for %s.%s: %w", templateName, tableName, methodName, err)
return nil, fmt.Errorf("render template %s for %s.%s: %w", templateName, sqlTableName, methodName, err)
}

buf.Write(rendered)
Expand All @@ -96,21 +99,21 @@ func (g *Generator) GenerateTable(
}

func (g *Generator) buildTemplateData(
tableName, methodName string,
sqlTableName, bareTableName, methodName string,
methodCfg *config.MethodConfig,
tableConfig config.TableConfig,
defaultCrud *config.CrudDefaultsConfig,
allColumns []string,
) (*TemplateData, string, error) {
data := &TemplateData{
TableName: tableName,
TableName: sqlTableName,
PrimaryColumn: tableConfig.PrimaryColumn,
}

// Resolve method name
// Resolve method name using bare table name (no schema prefix in Go identifiers)
data.MethodName = methodCfg.Name
if data.MethodName == "" {
data.MethodName = g.defaultMethodName(methodName, tableName, defaultCrud)
data.MethodName = g.defaultMethodName(methodName, bareTableName, defaultCrud)
}

// Soft delete
Expand Down Expand Up @@ -400,10 +403,18 @@ func cloneWhere(w map[string]config.WhereParamConfig) map[string]config.WherePar
}

// GetTableColumns extracts column names from a catalog table.
// tableName can be a bare name ("users") or schema-qualified ("shop.users").
// When schema-qualified, it matches only the specific schema; when bare,
// it searches all schemas (backward-compatible behavior).
func GetTableColumns(cat *catalog.Catalog, tableName string) []string {
schemaName, bareName := splitSchemaTable(tableName)

for _, schema := range cat.Schemas {
if schemaName != "" && schema.Name != schemaName {
continue
}
for _, table := range schema.Tables {
if table.Name == tableName {
if table.Name == bareName {
columns := make([]string, len(table.Columns))
for i, col := range table.Columns {
columns[i] = col.Name
Expand All @@ -414,3 +425,12 @@ func GetTableColumns(cat *catalog.Catalog, tableName string) []string {
}
return nil
}

// splitSchemaTable splits "schema.table" into ("schema", "table").
// For a bare "table" name it returns ("", "table").
func splitSchemaTable(name string) (string, string) {
if idx := strings.LastIndex(name, "."); idx != -1 {
return name[:idx], name[idx+1:]
}
return "", name
}
Loading