diff --git a/README.md b/README.md index 395eb5e..16f17fc 100644 --- a/README.md +++ b/README.md @@ -18,18 +18,69 @@ scima up --driver hana --dsn "hdb://user:pass@host:30015" --migrations-dir ./mig scima status --driver hana --dsn "hdb://user:pass@host:30015" --migrations-dir ./migrations # Postgres example -scima init --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations scima up --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations scima status --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations ``` ## Migration files -Create paired files for each version: ``` 0010_create_users_table.up.sql 0010_create_users_table.down.sql ``` +You can write portable migrations using schema placeholders that are substituted at runtime: +| Placeholder | Description | +|-----------------|-------------| +| `{{schema}}` | Required schema name (error if `--schema` not provided) | +| `{{schema?}}` | Optional schema prefix: becomes `schema.` when provided, otherwise empty | +| `\\{{schema}}`, `\\{{schema?}}` | Escape sequence: leaves token literal (no substitution) | + +Examples: + +```sql + + +## How schema is used + +The `--schema` flag serves two purposes: + +1. **Migration tracking table**: The schema is used to qualify the migration bookkeeping table (e.g., `schema_migrations` becomes `.schema_migrations`). +2. **SQL placeholders**: Any migration SQL file containing the placeholders `{{schema}}` or `{{schema?}}` will have these tokens replaced with the provided schema value (or omitted if not set and using the optional form). + +This allows you to: +- Track migrations in a schema-specific table +- Write portable migration SQL that adapts to different schemas without duplicating files + +**Escaping placeholders:** +To prevent substitution and keep the literal token in your SQL, prefix the placeholder with a single backslash (e.g., `\{{schema}}`). This is a literal backslash in your SQL file, not Go string escaping. +-- Uses required schema placeholder +CREATE TABLE {{schema}}.users ( + id BIGSERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL +); + +-- Optional: will create table in default schema if none supplied +CREATE TABLE {{schema?}}audit_log ( + id BIGSERIAL PRIMARY KEY, + event TEXT NOT NULL +); + +-- Escaped tokens remain untouched +-- This will literally create table named {{schema}}.raw_data (assuming dialect allows curly braces) +CREATE TABLE \{{schema}}.raw_data(id INT); +-- Optional escaped +CREATE TABLE \{{schema?}}metrics(id INT); +``` + +Run with a schema: + +```bash +scima up --driver postgres --dsn "$PG_DSN" --schema tenant_a --migrations-dir ./migrations +``` + +If `{{schema}}` appears and `--schema` is omitted, the command errors. + + ## Future roadmap ### Near-term enhancements - Dialect-specific migrations: for portability you can keep separate directories (e.g. `migrations_pg/`) when syntax differs (Postgres vs HANA column add syntax). The CLI currently points to one directory; run with `--migrations-dir` per dialect. diff --git a/cmd/scima/main.go b/cmd/scima/main.go index b98ab6c..71e07ac 100644 --- a/cmd/scima/main.go +++ b/cmd/scima/main.go @@ -20,11 +20,13 @@ var rootCmd = &cobra.Command{Use: "scima", Short: "Schema migrations for multipl var driver string var dsn string var migrationsDir string +var schema string // optional schema qualification func addGlobalFlags(cmd *cobra.Command) { cmd.PersistentFlags().StringVar(&driver, "driver", "hana", "Database driver/dialect (hana, pg, mysql, sqlite, etc.)") cmd.PersistentFlags().StringVar(&dsn, "dsn", "", "Database DSN / connection string") cmd.PersistentFlags().StringVar(&migrationsDir, "migrations-dir", "./migrations", "Directory containing migration files") + cmd.PersistentFlags().StringVar(&schema, "schema", "", "Optional database schema for migration tracking table and SQL placeholders ({{schema}}, {{schema?}})") } func init() { @@ -48,7 +50,7 @@ var initCmd = &cobra.Command{Use: "init", Short: "Initialize migration tracking fmt.Fprintf(os.Stderr, "error closing db: %v\n", err) } }() - if err := migr.Dialect.EnsureMigrationTable(context.Background(), migr.Conn); err != nil { + if err := migr.EnsureMigrationTable(context.Background()); err != nil { return err } fmt.Println("migration table ensured") @@ -164,7 +166,7 @@ func buildMigrator(cfg config.Config) (*migrate.Migrator, *sql.DB, error) { if err != nil { return nil, nil, err } - return migrate.NewMigrator(dial, dialect.SQLConn{DB: db}), db, nil + return migrate.NewMigrator(dial, dialect.SQLConn{DB: db}, schema), db, nil } func driverNameFor(driver string) string { diff --git a/internal/dialect/dialect.go b/internal/dialect/dialect.go index dcb6c9f..54c72a9 100644 --- a/internal/dialect/dialect.go +++ b/internal/dialect/dialect.go @@ -6,6 +6,10 @@ import ( "fmt" ) +const ( + migrationTable = "SCIMA_SCHEMA_MIGRATIONS" +) + // Conn abstracts minimal operations needed for migration execution. // Each dialect can wrap a DB connection or tx. // Exec should execute statements separated already (no multi-statement parsing here). @@ -30,10 +34,10 @@ type Rows interface { // Dialect binds SQL variants and introspection / DDL helpers. type Dialect interface { Name() string - EnsureMigrationTable(ctx context.Context, c Conn) error - SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error) - InsertVersion(ctx context.Context, c Conn, version int64) error - DeleteVersion(ctx context.Context, c Conn, version int64) error + EnsureMigrationTable(ctx context.Context, c Conn, schema string) error + SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error) + InsertVersion(ctx context.Context, c Conn, schema string, version int64) error + DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error } var registry = map[string]Dialect{} @@ -49,3 +53,10 @@ func Get(name string) (Dialect, error) { } return d, nil } + +func qualifiedMigrationTable(schema string) string { + if schema == "" { + return migrationTable + } + return fmt.Sprintf("\"%s\".%s", schema, migrationTable) +} diff --git a/internal/dialect/hana.go b/internal/dialect/hana.go index 6c4457c..550a0ad 100644 --- a/internal/dialect/hana.go +++ b/internal/dialect/hana.go @@ -14,18 +14,17 @@ func (h HanaDialect) Name() string { return "hana" } func init() { Register(HanaDialect{}) } -const hanaMigrationsTable = "SCIMA_SCHEMA_MIGRATIONS" // uppercase by convention in HANA - // EnsureMigrationTable creates the migration tracking table if it does not exist. -func (h HanaDialect) EnsureMigrationTable(ctx context.Context, c Conn) error { +func (h HanaDialect) EnsureMigrationTable(ctx context.Context, c Conn, schema string) error { // Try create table if not exists. HANA before 2.0 lacks standard IF NOT EXISTS for some DDL; we attempt and ignore errors. - create := fmt.Sprintf("CREATE TABLE %s (version BIGINT PRIMARY KEY)", hanaMigrationsTable) + table := qualifiedMigrationTable(schema) + create := fmt.Sprintf("CREATE TABLE %s (version BIGINT PRIMARY KEY)", table) if _, err := c.ExecContext(ctx, create); err != nil { // Ignore 'already exists' like sqlstate 301? We do a simple substring match. if !containsIgnoreCase(err.Error(), "exists") { // Could attempt a SELECT to verify existence. // Fallback: check selectable. - rows, qerr := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s WHERE 1=0", hanaMigrationsTable)) + rows, qerr := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s WHERE 1=0", table)) if qerr != nil { return fmt.Errorf("ensure migrations table failed: %v createErr: %v", qerr, err) } @@ -70,11 +69,12 @@ func stringIndexFold(hay, needle string) int { } // SelectAppliedVersions returns a map of applied migration versions from the tracking table. -func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error) { - rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", hanaMigrationsTable)) +func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error) { + table := qualifiedMigrationTable(schema) + rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", table)) if err != nil { // If table not existing treat as empty; attempt create then return empty. - if cerr := h.EnsureMigrationTable(ctx, c); cerr != nil { + if cerr := h.EnsureMigrationTable(ctx, c, schema); cerr != nil { return nil, err } return map[int64]bool{}, nil @@ -99,13 +99,15 @@ func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int } // InsertVersion inserts a migration version into the HANA migrations table. -func (h HanaDialect) InsertVersion(ctx context.Context, c Conn, version int64) error { - _, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", hanaMigrationsTable), version) +func (h HanaDialect) InsertVersion(ctx context.Context, c Conn, schema string, version int64) error { + table := qualifiedMigrationTable(schema) + _, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", table), version) return err } // DeleteVersion deletes a migration version from the HANA migrations table. -func (h HanaDialect) DeleteVersion(ctx context.Context, c Conn, version int64) error { - _, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = ?", hanaMigrationsTable), version) +func (h HanaDialect) DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error { + table := qualifiedMigrationTable(schema) + _, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = ?", table), version) return err } diff --git a/internal/dialect/postgres.go b/internal/dialect/postgres.go index b2e5326..6d1f809 100644 --- a/internal/dialect/postgres.go +++ b/internal/dialect/postgres.go @@ -12,18 +12,18 @@ type PostgresDialect struct{} // Name returns the name of the dialect ("postgres"). func (p PostgresDialect) Name() string { return "postgres" } -const pgMigrationsTable = "schema_migrations" - // EnsureMigrationTable creates the migration tracking table if it does not exist. -func (p PostgresDialect) EnsureMigrationTable(ctx context.Context, c Conn) error { - stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version BIGINT PRIMARY KEY)", pgMigrationsTable) +func (p PostgresDialect) EnsureMigrationTable(ctx context.Context, c Conn, schema string) error { + table := qualifiedMigrationTable(schema) + stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version BIGINT PRIMARY KEY)", table) _, err := c.ExecContext(ctx, stmt) return err } // SelectAppliedVersions returns a map of applied migration versions from the tracking table. -func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error) { - rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", pgMigrationsTable)) +func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error) { + table := qualifiedMigrationTable(schema) + rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", table)) if err != nil { return nil, err } @@ -47,14 +47,16 @@ func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map } // InsertVersion inserts a migration version into the Postgres migrations table. -func (p PostgresDialect) InsertVersion(ctx context.Context, c Conn, version int64) error { - _, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", pgMigrationsTable), version) +func (p PostgresDialect) InsertVersion(ctx context.Context, c Conn, schema string, version int64) error { + table := qualifiedMigrationTable(schema) + _, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", table), version) return err } // DeleteVersion deletes a migration version from the Postgres migrations table. -func (p PostgresDialect) DeleteVersion(ctx context.Context, c Conn, version int64) error { - _, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = $1", pgMigrationsTable), version) +func (p PostgresDialect) DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error { + table := qualifiedMigrationTable(schema) + _, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = $1", table), version) return err } diff --git a/internal/migrate/migrator.go b/internal/migrate/migrator.go index 722953a..19df7f3 100644 --- a/internal/migrate/migrator.go +++ b/internal/migrate/migrator.go @@ -4,6 +4,7 @@ package migrate import ( "context" "fmt" + "strings" "github.com/scima/scima/internal/dialect" ) @@ -12,26 +13,38 @@ import ( type Migrator struct { Conn dialect.Conn Dialect dialect.Dialect + Schema string // optional schema qualifier } // NewMigrator creates a new Migrator for the given dialect and connection. -func NewMigrator(d dialect.Dialect, c dialect.Conn) *Migrator { return &Migrator{Conn: c, Dialect: d} } +func NewMigrator(d dialect.Dialect, c dialect.Conn, schema string) *Migrator { + return &Migrator{Conn: c, Dialect: d, Schema: schema} +} + +// EnsureMigrationTable ensures the migration tracking table exists. +func (m *Migrator) EnsureMigrationTable(ctx context.Context) error { + return m.Dialect.EnsureMigrationTable(ctx, m.Conn, m.Schema) +} // Status returns applied version set. func (m *Migrator) Status(ctx context.Context) (map[int64]bool, error) { - if err := m.Dialect.EnsureMigrationTable(ctx, m.Conn); err != nil { + if err := m.Dialect.EnsureMigrationTable(ctx, m.Conn, m.Schema); err != nil { return nil, err } - return m.Dialect.SelectAppliedVersions(ctx, m.Conn) + return m.Dialect.SelectAppliedVersions(ctx, m.Conn, m.Schema) } // ApplyUp applies pending up migrations. func (m *Migrator) ApplyUp(ctx context.Context, ups []MigrationFile) error { for _, up := range ups { - if _, err := m.Conn.ExecContext(ctx, up.SQL); err != nil { + expanded, err := expandPlaceholders(up.SQL, m.Schema) + if err != nil { + return fmt.Errorf("placeholder expansion up %d: %w", up.Version, err) + } + if _, err := m.Conn.ExecContext(ctx, expanded); err != nil { return fmt.Errorf("apply up %d failed: %w", up.Version, err) } - if err := m.Dialect.InsertVersion(ctx, m.Conn, up.Version); err != nil { + if err := m.Dialect.InsertVersion(ctx, m.Conn, m.Schema, up.Version); err != nil { return err } } @@ -41,12 +54,68 @@ func (m *Migrator) ApplyUp(ctx context.Context, ups []MigrationFile) error { // ApplyDown applies downs. func (m *Migrator) ApplyDown(ctx context.Context, downs []MigrationFile) error { for _, down := range downs { - if _, err := m.Conn.ExecContext(ctx, down.SQL); err != nil { + expanded, err := expandPlaceholders(down.SQL, m.Schema) + if err != nil { + return fmt.Errorf("placeholder expansion down %d: %w", down.Version, err) + } + if _, err := m.Conn.ExecContext(ctx, expanded); err != nil { return fmt.Errorf("apply down %d failed: %w", down.Version, err) } - if err := m.Dialect.DeleteVersion(ctx, m.Conn, down.Version); err != nil { + if err := m.Dialect.DeleteVersion(ctx, m.Conn, m.Schema, down.Version); err != nil { return err } } return nil } + +const ( + requiredSchemaToken = "{{schema}}" + optionalSchemaToken = "{{schema?}}" +) + +// expandPlaceholders substitutes schema tokens in SQL. +// {{schema}} requires a non-empty schema; {{schema?}} inserts schema plus dot or nothing. +func expandPlaceholders(sql string, schema string) (string, error) { + // We process with a manual scan to support escaping via backslash: \{{schema}} or \{{schema?}} remain literal. + // Strategy: iterate runes, detect backslash before placeholder start; build output. + var b strings.Builder + // Precompute for efficiency. + req := requiredSchemaToken + opt := optionalSchemaToken + for i := 0; i < len(sql); { + // Handle escaped required placeholder + if i+1+len(req) <= len(sql) && sql[i] == '\\' && sql[i+1:i+1+len(req)] == req { + b.WriteString(req) // drop escape, keep literal token + i += 1 + len(req) + continue + } + // Handle escaped optional placeholder + if i+1+len(opt) <= len(sql) && sql[i] == '\\' && sql[i+1:i+1+len(opt)] == opt { + b.WriteString(opt) + i += 1 + len(opt) + continue + } + // Unescaped optional + if i+len(opt) <= len(sql) && sql[i:i+len(opt)] == opt { + if schema != "" { + b.WriteString(schema) + b.WriteByte('.') + } + i += len(opt) + continue + } + // Unescaped required + if i+len(req) <= len(sql) && sql[i:i+len(req)] == req { + if schema == "" { + return "", fmt.Errorf("%s used but schema not set", requiredSchemaToken) + } + b.WriteString(schema) + i += len(req) + continue + } + // Default: copy one byte + b.WriteByte(sql[i]) + i++ + } + return b.String(), nil +} diff --git a/internal/migrate/migrator_test.go b/internal/migrate/migrator_test.go index 170bab9..33cc40f 100644 --- a/internal/migrate/migrator_test.go +++ b/internal/migrate/migrator_test.go @@ -60,23 +60,25 @@ func (r mockRows) Err() error { return nil } type mockDialect struct{ versions map[int64]bool } -func (d mockDialect) Name() string { return "mock" } -func (d mockDialect) EnsureMigrationTable(_ context.Context, _ dialect.Conn) error { return nil } -func (d mockDialect) SelectAppliedVersions(_ context.Context, _ dialect.Conn) (map[int64]bool, error) { +func (d mockDialect) Name() string { return "mock" } +func (d mockDialect) EnsureMigrationTable(_ context.Context, _ dialect.Conn, _ string) error { + return nil +} +func (d mockDialect) SelectAppliedVersions(_ context.Context, _ dialect.Conn, _ string) (map[int64]bool, error) { return d.versions, nil } -func (d mockDialect) InsertVersion(_ context.Context, _ dialect.Conn, version int64) error { +func (d mockDialect) InsertVersion(_ context.Context, _ dialect.Conn, _ string, version int64) error { d.versions[version] = true return nil } -func (d mockDialect) DeleteVersion(_ context.Context, _ dialect.Conn, version int64) error { +func (d mockDialect) DeleteVersion(_ context.Context, _ dialect.Conn, _ string, version int64) error { delete(d.versions, version) return nil } func TestMigratorApplyUpDown(t *testing.T) { versions := map[int64]bool{10: true} - migr := NewMigrator(mockDialect{versions: versions}, &mockConn{Versions: versions}) + migr := NewMigrator(mockDialect{versions: versions}, &mockConn{Versions: versions}, "") ups := []MigrationFile{{Version: 20, Name: "add_col", Direction: "up", SQL: "ALTER"}} if err := migr.ApplyUp(context.Background(), ups); err != nil { t.Fatalf("apply up: %v", err) diff --git a/internal/migrate/placeholder_test.go b/internal/migrate/placeholder_test.go new file mode 100644 index 0000000..93402b0 --- /dev/null +++ b/internal/migrate/placeholder_test.go @@ -0,0 +1,85 @@ +package migrate + +import ( + "testing" +) + +func TestExpandPlaceholdersRequired(t *testing.T) { + in := "CREATE TABLE {{schema}}.users(id INT);" + out, err := expandPlaceholders(in, "tenant1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "CREATE TABLE tenant1.users(id INT);" { + t.Fatalf("mismatch: %s", out) + } +} + +func TestExpandPlaceholdersOptionalWithSchema(t *testing.T) { + in := "CREATE TABLE {{schema?}}users(id INT);" + out, err := expandPlaceholders(in, "tenant1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "CREATE TABLE tenant1.users(id INT);" { + t.Fatalf("mismatch: %s", out) + } +} + +func TestExpandPlaceholdersOptionalNoSchema(t *testing.T) { + in := "CREATE TABLE {{schema?}}users(id INT);" + out, err := expandPlaceholders(in, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "CREATE TABLE users(id INT);" { + t.Fatalf("mismatch: %s", out) + } +} + +func TestExpandPlaceholdersRequiredMissingSchema(t *testing.T) { + in := "CREATE TABLE {{schema}}.users(id INT);" + if _, err := expandPlaceholders(in, ""); err == nil { + t.Fatalf("expected error when schema empty with required placeholder") + } +} + +func TestExpandPlaceholdersMultiple(t *testing.T) { + in := "INSERT INTO {{schema}}.a SELECT * FROM {{schema}}.b;" + out, err := expandPlaceholders(in, "tenant1") + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if out != "INSERT INTO tenant1.a SELECT * FROM tenant1.b;" { + t.Fatalf("mismatch: %s", out) + } +} + +func TestExpandPlaceholdersEscapedRequired(t *testing.T) { + in := "CREATE TABLE \\{{schema}}.users(id INT);" // escaped token should remain literal + out, err := expandPlaceholders(in, "tenant1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "CREATE TABLE {{schema}}.users(id INT);" { + t.Fatalf("mismatch escaped required: %s", out) + } +} + +func TestExpandPlaceholdersEscapedOptional(t *testing.T) { + in := "CREATE TABLE \\{{schema?}}users(id INT);" + out, err := expandPlaceholders(in, "tenant1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "CREATE TABLE {{schema?}}users(id INT);" { + t.Fatalf("mismatch escaped optional with schema: %s", out) + } + out2, err := expandPlaceholders(in, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out2 != "CREATE TABLE {{schema?}}users(id INT);" { + t.Fatalf("mismatch escaped optional without schema: %s", out2) + } +} diff --git a/tests/integration/postgres/postgres_integration_test.go b/tests/integration/postgres/postgres_integration_test.go index 3e7a256..62dc88e 100644 --- a/tests/integration/postgres/postgres_integration_test.go +++ b/tests/integration/postgres/postgres_integration_test.go @@ -87,6 +87,11 @@ func TestPostgresMigrationsIntegration(t *testing.T) { if err := db.PingContext(ctx); err != nil { t.Fatalf("ping: %v", err) } + // Setup test schema + _, err = db.ExecContext(ctx, `CREATE SCHEMA IF NOT EXISTS test_schema`) + if err != nil { + t.Fatalf("create schema: %v", err) + } // --------------------------------------------------------------------- // MIGRATOR: Acquire dialect and create migrator wrapper @@ -95,7 +100,7 @@ func TestPostgresMigrationsIntegration(t *testing.T) { if err != nil { t.Fatalf("get dialect: %v", err) } - migr := migrate.NewMigrator(d, dialect.SQLConn{DB: db}) + migr := migrate.NewMigrator(d, dialect.SQLConn{DB: db}, "test_schema") // --------------------------------------------------------------------- // DISCOVERY: Locate migrations directory and parse & validate files