From 56ed1ad613dbea52b29278c107b11e77f259709d Mon Sep 17 00:00:00 2001 From: rubin Date: Thu, 19 Feb 2026 18:04:42 -0300 Subject: [PATCH] wip: validate column names in ON CONFLICT DO UPDATE SET - Add OnConflictClause validation in internal/sql/validate/on_conflict.go - Integrate validation into analyze.go pipeline - Add unit tests (5 cases) - Add endtoend testcase for invalid column TODO: - Test qualified column names (cart_items.col) - Test composite expressions on right side of SET - Test WHERE clause in ON CONFLICT - Test explicit schema (public.cart_items) --- internal/compiler/analyze.go | 3 + .../postgresql/pgx/v5/query.sql | 6 + .../postgresql/pgx/v5/schema.sql | 7 + .../postgresql/pgx/v5/sqlc.json | 17 +++ .../postgresql/pgx/v5/stderr.txt | 2 + internal/sql/validate/on_conflict.go | 106 ++++++++++++++ internal/sql/validate/on_conflict_test.go | 130 ++++++++++++++++++ 7 files changed, 271 insertions(+) create mode 100644 internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/query.sql create mode 100644 internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/schema.sql create mode 100644 internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/sqlc.json create mode 100644 internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/stderr.txt create mode 100644 internal/sql/validate/on_conflict.go create mode 100644 internal/sql/validate/on_conflict_test.go diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..4d7c924391 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -152,6 +152,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } + if err := check(validate.OnConflictClause(c.catalog, n, table)); err != nil { + return nil, err + } } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { diff --git a/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/query.sql new file mode 100644 index 0000000000..91fd1a9289 --- /dev/null +++ b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/query.sql @@ -0,0 +1,6 @@ +-- name: AddItem :exec +INSERT INTO cart_items (owner_id, product_id, price_amount, price_currency) +VALUES ($1, $2, $3, $4) +ON CONFLICT (owner_id, product_id) DO UPDATE + SET price_amount1 = EXCLUDED.price_amount1, + price_currency = EXCLUDED.price_currency; diff --git a/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/schema.sql b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/schema.sql new file mode 100644 index 0000000000..e487943b55 --- /dev/null +++ b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/schema.sql @@ -0,0 +1,7 @@ +CREATE TABLE cart_items ( + owner_id VARCHAR(255) NOT NULL, + product_id UUID NOT NULL, + price_amount DECIMAL NOT NULL, + price_currency VARCHAR(3) NOT NULL, + PRIMARY KEY (owner_id, product_id) +); diff --git a/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/sqlc.json b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/sqlc.json new file mode 100644 index 0000000000..0bb9b4b3db --- /dev/null +++ b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/sqlc.json @@ -0,0 +1,17 @@ +{ + "version": "2", + "sql": [ + { + "engine": "postgresql", + "schema": "schema.sql", + "queries": "query.sql", + "gen": { + "go": { + "package": "querytest", + "out": "go", + "sql_package": "pgx/v5" + } + } + } + ] +} diff --git a/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/stderr.txt b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/stderr.txt new file mode 100644 index 0000000000..f5fbe541ae --- /dev/null +++ b/internal/endtoend/testdata/on_conflict_invalid_column/postgresql/pgx/v5/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:5:9: column "price_amount1" of relation "cart_items" does not exist diff --git a/internal/sql/validate/on_conflict.go b/internal/sql/validate/on_conflict.go new file mode 100644 index 0000000000..cca82f1fd8 --- /dev/null +++ b/internal/sql/validate/on_conflict.go @@ -0,0 +1,106 @@ +package validate + +import ( + "fmt" + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +func OnConflictClause(cat *catalog.Catalog, stmt *ast.InsertStmt, tableName *ast.TableName) error { + if stmt.OnConflictClause == nil { + return nil + } + + occ := stmt.OnConflictClause + + if occ.Action != ast.OnConflictActionUpdate { + return nil + } + + if tableName == nil { + return nil + } + + tbl, err := cat.GetTable(tableName) + if err != nil { + return err + } + + relName := "" + if tbl.Rel != nil { + relName = tbl.Rel.Name + } + + validCols := make(map[string]struct{}, len(tbl.Columns)) + for _, c := range tbl.Columns { + validCols[strings.ToLower(c.Name)] = struct{}{} + } + + if occ.TargetList == nil { + return nil + } + + for _, item := range occ.TargetList.Items { + res, ok := item.(*ast.ResTarget) + if !ok { + continue + } + + if res.Name != nil { + colName := strings.ToLower(*res.Name) + if _, exists := validCols[colName]; !exists { + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column %q of relation %q does not exist", *res.Name, relName), + Location: res.Location, + } + } + } + + if res.Val != nil { + if err := validateExcludedRefs(res.Val, validCols, relName); err != nil { + return err + } + } + } + + return nil +} + +func validateExcludedRefs(node ast.Node, validCols map[string]struct{}, tableName string) error { + refs := astutils.Search(node, func(n ast.Node) bool { + _, ok := n.(*ast.ColumnRef) + return ok + }) + + for _, ref := range refs.Items { + colRef, ok := ref.(*ast.ColumnRef) + if !ok { + continue + } + + parts := make([]string, 0, len(colRef.Fields.Items)) + for _, field := range colRef.Fields.Items { + if s, ok := field.(*ast.String); ok { + parts = append(parts, s.Str) + } + } + + if len(parts) == 2 && strings.ToLower(parts[0]) == "excluded" { + colName := strings.ToLower(parts[1]) + if _, exists := validCols[colName]; !exists { + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column %q does not exist in relation %q (via EXCLUDED)", parts[1], tableName), + Location: colRef.Location, + } + } + } + } + + return nil +} diff --git a/internal/sql/validate/on_conflict_test.go b/internal/sql/validate/on_conflict_test.go new file mode 100644 index 0000000000..e5212ad031 --- /dev/null +++ b/internal/sql/validate/on_conflict_test.go @@ -0,0 +1,130 @@ +package validate + +import ( + "strings" + "testing" + + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func makeTestCatalog(t *testing.T) (*catalog.Catalog, *ast.TableName) { + t.Helper() + + p := postgresql.NewParser() + stmts, err := p.Parse(strings.NewReader(` + CREATE TABLE cart_items ( + owner_id VARCHAR(255) NOT NULL, + product_id UUID NOT NULL, + price_amount DECIMAL NOT NULL, + price_currency VARCHAR(3) NOT NULL, + PRIMARY KEY (owner_id, product_id) + ); + `)) + if err != nil { + t.Fatalf("parse schema: %v", err) + } + + cat := catalog.New("public") + for _, stmt := range stmts { + if err := cat.Update(stmt, nil); err != nil { + t.Fatalf("update catalog: %v", err) + } + } + + tableName := &ast.TableName{Schema: "public", Name: "cart_items"} + return cat, tableName +} + +func makeStmt(action ast.OnConflictAction, setItems []struct{ col, val string }) *ast.InsertStmt { + stmt := &ast.InsertStmt{ + Relation: &ast.RangeVar{ + Schemaname: strPtr("public"), + Relname: strPtr("cart_items"), + }, + } + + if action == ast.OnConflictActionNone { + return stmt + } + + items := make([]ast.Node, 0, len(setItems)) + for _, si := range setItems { + colName := si.col + items = append(items, &ast.ResTarget{ + Name: &colName, + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "excluded"}, + &ast.String{Str: si.val}, + }, + }, + }, + }) + } + + stmt.OnConflictClause = &ast.OnConflictClause{ + Action: action, + TargetList: &ast.List{Items: items}, + } + return stmt +} + +func strPtr(s string) *string { return &s } + +func TestOnConflictClause(t *testing.T) { + cat, tableName := makeTestCatalog(t) + + tests := []struct { + name string + stmt *ast.InsertStmt + wantErr bool + }{ + { + name: "valid columns in SET and EXCLUDED", + stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{ + {"price_amount", "price_amount"}, + {"price_currency", "price_currency"}, + }), + wantErr: false, + }, + { + name: "invalid column on left side of SET", + stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{ + {"price_amount1", "price_amount"}, + }), + wantErr: true, + }, + { + name: "invalid EXCLUDED reference on right side", + stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{ + {"price_amount", "price_amount1"}, + }), + wantErr: true, + }, + { + name: "DO NOTHING skips column validation", + stmt: makeStmt(ast.OnConflictActionNothing, nil), + wantErr: false, + }, + { + name: "no OnConflictClause passes without error", + stmt: makeStmt(ast.OnConflictActionNone, nil), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := OnConflictClause(cat, tt.stmt, tableName) + if tt.wantErr && err == nil { + t.Error("expected error but got none") + } + if !tt.wantErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +}