diff --git a/lsmysql/mysql.go b/lsmysql/mysql.go index 916404c..1dc9d82 100644 --- a/lsmysql/mysql.go +++ b/lsmysql/mysql.go @@ -391,11 +391,21 @@ func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Lo return err } if schema != "" { - _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` - CREATE SCHEMA IF NOT EXISTS %s - `, schema)) + var schemaExists bool + err = d.DB().QueryRowContext(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)`, + schema, + ).Scan(&schemaExists) if err != nil { - return errors.Wrapf(err, "could not create libschema schema '%s'", schema) + return errors.Wrapf(err, "could not check if libschema schema '%s' exists", schema) + } + if !schemaExists { + _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` + CREATE SCHEMA IF NOT EXISTS %s + `, schema)) + if err != nil { + return errors.Wrapf(err, "could not create libschema schema '%s'", schema) + } } } _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` diff --git a/lsmysql/mysql_test.go b/lsmysql/mysql_test.go index 1f4eec0..1504d56 100644 --- a/lsmysql/mysql_test.go +++ b/lsmysql/mysql_test.go @@ -91,6 +91,9 @@ func testMysqlHappyPath(t *testing.T, dsn string, createPostfix string, driverNe }() defer cleanup(db) + _, err = db.Exec("CREATE SCHEMA " + options.SchemaOverride) + require.NoError(t, err) + s := libschema.New(context.Background(), options) dbase, _, err := driverNew(t, "test", s, db) require.NoError(t, err, "libschema NewDatabase") diff --git a/lspostgres/postgres.go b/lspostgres/postgres.go index 87cc94d..cd571cd 100644 --- a/lspostgres/postgres.go +++ b/lspostgres/postgres.go @@ -330,25 +330,35 @@ func (p *Postgres) DoOneMigration(ctx context.Context, log *internal.Log, d *lib // CreateSchemaTableIfNotExists creates the migration tracking table for libschema. // It is expected to be called by libschema. func (p *Postgres) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Log, d *libschema.Database) error { - schema, tableName, err := trackingSchemaTable(d) + schemaName, schema, tableName, err := trackingSchemaTable(d) if err != nil { return err } - for { - if schema != "" { - _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` - CREATE SCHEMA IF NOT EXISTS %s - `, schema)) - if err != nil { + if schemaName != "" { + var schemaExists bool + err = d.DB().QueryRowContext(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = $1)`, + schemaName, + ).Scan(&schemaExists) + if err != nil { + return errors.Wrapf(err, "could not check if libschema schema '%s' exists", schemaName) + } + if !schemaExists { + for { + _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` + CREATE SCHEMA IF NOT EXISTS %s + `, schema)) + if err == nil { + break + } if strings.Contains(err.Error(), `pq: duplicate key value violates unique constraint "pg_namespace_nspname_index"`) { p.log.Warn("Ignoring create schema collision with another transaction and trying again") time.Sleep(time.Second) continue } - return errors.Wrapf(err, "could not create libschema schema '%s'", schema) + return errors.Wrapf(err, "could not create libschema schema '%s'", schemaName) } } - break } for { _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` @@ -374,25 +384,26 @@ func (p *Postgres) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal return nil } -func trackingSchemaTable(d *libschema.Database) (string, string, error) { +func trackingSchemaTable(d *libschema.Database) (schemaName string, quotedSchema string, quotedTable string, err error) { tableName := d.Options.TrackingTable s := strings.Split(tableName, ".") switch len(s) { case 2: - schema := pq.QuoteIdentifier(s[0]) + schemaName = s[0] + quotedSchema = pq.QuoteIdentifier(schemaName) table := pq.QuoteIdentifier(s[1]) - return schema, schema + "." + table, nil + return schemaName, quotedSchema, quotedSchema + "." + table, nil case 1: - return "", pq.QuoteIdentifier(tableName), nil + return "", "", pq.QuoteIdentifier(tableName), nil default: - return "", "", errors.Errorf("tracking table '%s' is not valid", tableName) + return "", "", "", errors.Errorf("tracking table '%s' is not valid", tableName) } } // trackingTable returns the schema+table reference for the migration tracking table. // The name is already quoted properly for use as a save postgres identifier. func trackingTable(d *libschema.Database) string { - _, table, _ := trackingSchemaTable(d) + _, _, table, _ := trackingSchemaTable(d) return table } diff --git a/lspostgres/schema_override_test.go b/lspostgres/schema_override_test.go index 90d2294..a2b02f8 100644 --- a/lspostgres/schema_override_test.go +++ b/lspostgres/schema_override_test.go @@ -31,6 +31,8 @@ func TestSchemaOverrideTransactional(t *testing.T) { t.Cleanup(func() { cleanup(db) }) ctx := context.Background() + _, err = db.Exec("CREATE SCHEMA " + opts.SchemaOverride) + require.NoError(t, err) s := libschema.New(ctx, opts) log := libschema.LogFromLog(t) dbase, err := lspostgres.New(log, "schema_override_tx", s, db) diff --git a/lssinglestore/repeat_test.go b/lssinglestore/repeat_test.go index 3497bce..4c0fd49 100644 --- a/lssinglestore/repeat_test.go +++ b/lssinglestore/repeat_test.go @@ -28,6 +28,9 @@ func TestRepeat(t *testing.T) { defer func() { cleanup(db) }() options.DebugLogging = true + _, err = db.Exec("CREATE DATABASE IF NOT EXISTS " + options.SchemaOverride + " PARTITIONS 2") + require.NoError(t, err) + s := libschema.New(context.Background(), options) log := libschema.LogFromLog(t) dbase, _, err := lssinglestore.New(log, "test", s, db) diff --git a/lssinglestore/singlestore.go b/lssinglestore/singlestore.go index 1f2e272..91f4918 100644 --- a/lssinglestore/singlestore.go +++ b/lssinglestore/singlestore.go @@ -206,12 +206,20 @@ func makeID(raw string) (string, error) { } } +func normalizeSchemaName(raw string) string { + if len(raw) >= 2 && strings.HasPrefix(raw, "`") && strings.HasSuffix(raw, "`") { + return raw[1 : len(raw)-1] + } + return raw +} + func trackingSchemaTable(d *libschema.Database) (string, string, string, error) { tableName := d.Options.TrackingTable s := strings.Split(tableName, ".") switch len(s) { case 2: - schema, err := makeID(s[0]) + schema := s[0] + quotedSchema, err := makeID(schema) if err != nil { return "", "", "", errors.Wrap(err, "cannot make tracking table schema name") } @@ -219,7 +227,7 @@ func trackingSchemaTable(d *libschema.Database) (string, string, string, error) if err != nil { return "", "", "", errors.Wrap(err, "cannot make tracking table table name") } - return schema, schema + "." + table, table, nil + return schema, quotedSchema + "." + table, table, nil case 1: table, err := makeID(tableName) if err != nil { @@ -238,11 +246,26 @@ func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *inter return err } if schema != "" { - _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` - CREATE DATABASE IF NOT EXISTS %s PARTITIONS 2 - `, schema)) + schemaName := normalizeSchemaName(schema) + var schemaExists bool + err = d.DB().QueryRowContext(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)`, + schemaName, + ).Scan(&schemaExists) if err != nil { - return errors.Wrapf(err, "could not create libschema schema '%s'", schema) + return errors.Wrapf(err, "could not check if libschema schema '%s' exists", schemaName) + } + if !schemaExists { + quotedSchema, err := makeID(schema) + if err != nil { + return errors.Wrap(err, "cannot make tracking table schema name") + } + _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` + CREATE DATABASE IF NOT EXISTS %s PARTITIONS 2 + `, quotedSchema)) + if err != nil { + return errors.Wrapf(err, "could not create libschema schema '%s'", schemaName) + } } } _, err = d.DB().ExecContext(ctx, fmt.Sprintf(`