Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,69 @@ scima up --driver hana --dsn "hdb://user:pass@host:30015" --migrations-dir ./mig
scima status --driver hana --dsn "hdb://user:pass@host:30015" --migrations-dir ./migrations

# Postgres example
scima init --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations
scima up --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations
scima status --driver postgres --dsn "postgres://user:pass@localhost:5432/mydb?sslmode=disable" --migrations-dir ./migrations
```

## Migration files
Create paired files for each version:
```
0010_create_users_table.up.sql
0010_create_users_table.down.sql
```

You can write portable migrations using schema placeholders that are substituted at runtime:
| Placeholder | Description |
|-----------------|-------------|
| `{{schema}}` | Required schema name (error if `--schema` not provided) |
| `{{schema?}}` | Optional schema prefix: becomes `schema.` when provided, otherwise empty |
| `\\{{schema}}`, `\\{{schema?}}` | Escape sequence: leaves token literal (no substitution) |

Examples:

```sql


## How schema is used

The `--schema` flag serves two purposes:

1. **Migration tracking table**: The schema is used to qualify the migration bookkeeping table (e.g., `schema_migrations` becomes `<schema>.schema_migrations`).
2. **SQL placeholders**: Any migration SQL file containing the placeholders `{{schema}}` or `{{schema?}}` will have these tokens replaced with the provided schema value (or omitted if not set and using the optional form).

This allows you to:
- Track migrations in a schema-specific table
- Write portable migration SQL that adapts to different schemas without duplicating files

**Escaping placeholders:**
To prevent substitution and keep the literal token in your SQL, prefix the placeholder with a single backslash (e.g., `\{{schema}}`). This is a literal backslash in your SQL file, not Go string escaping.
-- Uses required schema placeholder
CREATE TABLE {{schema}}.users (
id BIGSERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL
);

-- Optional: will create table in default schema if none supplied
CREATE TABLE {{schema?}}audit_log (
id BIGSERIAL PRIMARY KEY,
event TEXT NOT NULL
);

-- Escaped tokens remain untouched
-- This will literally create table named {{schema}}.raw_data (assuming dialect allows curly braces)
CREATE TABLE \{{schema}}.raw_data(id INT);
-- Optional escaped
CREATE TABLE \{{schema?}}metrics(id INT);
```

Run with a schema:

```bash
scima up --driver postgres --dsn "$PG_DSN" --schema tenant_a --migrations-dir ./migrations
```

If `{{schema}}` appears and `--schema` is omitted, the command errors.


## Future roadmap
### Near-term enhancements
- Dialect-specific migrations: for portability you can keep separate directories (e.g. `migrations_pg/`) when syntax differs (Postgres vs HANA column add syntax). The CLI currently points to one directory; run with `--migrations-dir` per dialect.
Expand Down
6 changes: 4 additions & 2 deletions cmd/scima/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ var rootCmd = &cobra.Command{Use: "scima", Short: "Schema migrations for multipl
var driver string
var dsn string
var migrationsDir string
var schema string // optional schema qualification

func addGlobalFlags(cmd *cobra.Command) {
cmd.PersistentFlags().StringVar(&driver, "driver", "hana", "Database driver/dialect (hana, pg, mysql, sqlite, etc.)")
cmd.PersistentFlags().StringVar(&dsn, "dsn", "", "Database DSN / connection string")
cmd.PersistentFlags().StringVar(&migrationsDir, "migrations-dir", "./migrations", "Directory containing migration files")
cmd.PersistentFlags().StringVar(&schema, "schema", "", "Optional database schema for migration tracking table and SQL placeholders ({{schema}}, {{schema?}})")
}

func init() {
Expand All @@ -48,7 +50,7 @@ var initCmd = &cobra.Command{Use: "init", Short: "Initialize migration tracking
fmt.Fprintf(os.Stderr, "error closing db: %v\n", err)
}
}()
if err := migr.Dialect.EnsureMigrationTable(context.Background(), migr.Conn); err != nil {
if err := migr.EnsureMigrationTable(context.Background()); err != nil {
return err
}
fmt.Println("migration table ensured")
Expand Down Expand Up @@ -164,7 +166,7 @@ func buildMigrator(cfg config.Config) (*migrate.Migrator, *sql.DB, error) {
if err != nil {
return nil, nil, err
}
return migrate.NewMigrator(dial, dialect.SQLConn{DB: db}), db, nil
return migrate.NewMigrator(dial, dialect.SQLConn{DB: db}, schema), db, nil
}

func driverNameFor(driver string) string {
Expand Down
19 changes: 15 additions & 4 deletions internal/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"fmt"
)

const (
migrationTable = "SCIMA_SCHEMA_MIGRATIONS"
)

// Conn abstracts minimal operations needed for migration execution.
// Each dialect can wrap a DB connection or tx.
// Exec should execute statements separated already (no multi-statement parsing here).
Expand All @@ -30,10 +34,10 @@ type Rows interface {
// Dialect binds SQL variants and introspection / DDL helpers.
type Dialect interface {
Name() string
EnsureMigrationTable(ctx context.Context, c Conn) error
SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error)
InsertVersion(ctx context.Context, c Conn, version int64) error
DeleteVersion(ctx context.Context, c Conn, version int64) error
EnsureMigrationTable(ctx context.Context, c Conn, schema string) error
SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error)
InsertVersion(ctx context.Context, c Conn, schema string, version int64) error
DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error
}

var registry = map[string]Dialect{}
Expand All @@ -49,3 +53,10 @@ func Get(name string) (Dialect, error) {
}
return d, nil
}

func qualifiedMigrationTable(schema string) string {
if schema == "" {
return migrationTable
}
return fmt.Sprintf("\"%s\".%s", schema, migrationTable)
}
26 changes: 14 additions & 12 deletions internal/dialect/hana.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,17 @@ func (h HanaDialect) Name() string { return "hana" }

func init() { Register(HanaDialect{}) }

const hanaMigrationsTable = "SCIMA_SCHEMA_MIGRATIONS" // uppercase by convention in HANA

// EnsureMigrationTable creates the migration tracking table if it does not exist.
func (h HanaDialect) EnsureMigrationTable(ctx context.Context, c Conn) error {
func (h HanaDialect) EnsureMigrationTable(ctx context.Context, c Conn, schema string) error {
// Try create table if not exists. HANA before 2.0 lacks standard IF NOT EXISTS for some DDL; we attempt and ignore errors.
create := fmt.Sprintf("CREATE TABLE %s (version BIGINT PRIMARY KEY)", hanaMigrationsTable)
table := qualifiedMigrationTable(schema)
create := fmt.Sprintf("CREATE TABLE %s (version BIGINT PRIMARY KEY)", table)
if _, err := c.ExecContext(ctx, create); err != nil {
// Ignore 'already exists' like sqlstate 301? We do a simple substring match.
if !containsIgnoreCase(err.Error(), "exists") {
// Could attempt a SELECT to verify existence.
// Fallback: check selectable.
rows, qerr := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s WHERE 1=0", hanaMigrationsTable))
rows, qerr := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s WHERE 1=0", table))
if qerr != nil {
return fmt.Errorf("ensure migrations table failed: %v createErr: %v", qerr, err)
}
Expand Down Expand Up @@ -70,11 +69,12 @@ func stringIndexFold(hay, needle string) int {
}

// SelectAppliedVersions returns a map of applied migration versions from the tracking table.
func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error) {
rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", hanaMigrationsTable))
func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error) {
table := qualifiedMigrationTable(schema)
rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", table))
if err != nil {
// If table not existing treat as empty; attempt create then return empty.
if cerr := h.EnsureMigrationTable(ctx, c); cerr != nil {
if cerr := h.EnsureMigrationTable(ctx, c, schema); cerr != nil {
return nil, err
}
return map[int64]bool{}, nil
Expand All @@ -99,13 +99,15 @@ func (h HanaDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int
}

// InsertVersion inserts a migration version into the HANA migrations table.
func (h HanaDialect) InsertVersion(ctx context.Context, c Conn, version int64) error {
_, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", hanaMigrationsTable), version)
func (h HanaDialect) InsertVersion(ctx context.Context, c Conn, schema string, version int64) error {
table := qualifiedMigrationTable(schema)
_, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", table), version)
return err
}

// DeleteVersion deletes a migration version from the HANA migrations table.
func (h HanaDialect) DeleteVersion(ctx context.Context, c Conn, version int64) error {
_, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = ?", hanaMigrationsTable), version)
func (h HanaDialect) DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error {
table := qualifiedMigrationTable(schema)
_, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = ?", table), version)
return err
}
22 changes: 12 additions & 10 deletions internal/dialect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ type PostgresDialect struct{}
// Name returns the name of the dialect ("postgres").
func (p PostgresDialect) Name() string { return "postgres" }

const pgMigrationsTable = "schema_migrations"

// EnsureMigrationTable creates the migration tracking table if it does not exist.
func (p PostgresDialect) EnsureMigrationTable(ctx context.Context, c Conn) error {
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version BIGINT PRIMARY KEY)", pgMigrationsTable)
func (p PostgresDialect) EnsureMigrationTable(ctx context.Context, c Conn, schema string) error {
table := qualifiedMigrationTable(schema)
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version BIGINT PRIMARY KEY)", table)
_, err := c.ExecContext(ctx, stmt)
return err
}

// SelectAppliedVersions returns a map of applied migration versions from the tracking table.
func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map[int64]bool, error) {
rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", pgMigrationsTable))
func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn, schema string) (map[int64]bool, error) {
table := qualifiedMigrationTable(schema)
rows, err := c.QueryContext(ctx, fmt.Sprintf("SELECT version FROM %s", table))
if err != nil {
return nil, err
}
Expand All @@ -47,14 +47,16 @@ func (p PostgresDialect) SelectAppliedVersions(ctx context.Context, c Conn) (map
}

// InsertVersion inserts a migration version into the Postgres migrations table.
func (p PostgresDialect) InsertVersion(ctx context.Context, c Conn, version int64) error {
_, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", pgMigrationsTable), version)
func (p PostgresDialect) InsertVersion(ctx context.Context, c Conn, schema string, version int64) error {
table := qualifiedMigrationTable(schema)
_, err := c.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", table), version)
return err
}

// DeleteVersion deletes a migration version from the Postgres migrations table.
func (p PostgresDialect) DeleteVersion(ctx context.Context, c Conn, version int64) error {
_, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = $1", pgMigrationsTable), version)
func (p PostgresDialect) DeleteVersion(ctx context.Context, c Conn, schema string, version int64) error {
table := qualifiedMigrationTable(schema)
_, err := c.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s WHERE version = $1", table), version)
return err
}

Expand Down
83 changes: 76 additions & 7 deletions internal/migrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package migrate
import (
"context"
"fmt"
"strings"

"github.com/scima/scima/internal/dialect"
)
Expand All @@ -12,26 +13,38 @@ import (
type Migrator struct {
Conn dialect.Conn
Dialect dialect.Dialect
Schema string // optional schema qualifier
}

// NewMigrator creates a new Migrator for the given dialect and connection.
func NewMigrator(d dialect.Dialect, c dialect.Conn) *Migrator { return &Migrator{Conn: c, Dialect: d} }
func NewMigrator(d dialect.Dialect, c dialect.Conn, schema string) *Migrator {
return &Migrator{Conn: c, Dialect: d, Schema: schema}
}

// EnsureMigrationTable ensures the migration tracking table exists.
func (m *Migrator) EnsureMigrationTable(ctx context.Context) error {
return m.Dialect.EnsureMigrationTable(ctx, m.Conn, m.Schema)
}

// Status returns applied version set.
func (m *Migrator) Status(ctx context.Context) (map[int64]bool, error) {
if err := m.Dialect.EnsureMigrationTable(ctx, m.Conn); err != nil {
if err := m.Dialect.EnsureMigrationTable(ctx, m.Conn, m.Schema); err != nil {
return nil, err
}
return m.Dialect.SelectAppliedVersions(ctx, m.Conn)
return m.Dialect.SelectAppliedVersions(ctx, m.Conn, m.Schema)
}

// ApplyUp applies pending up migrations.
func (m *Migrator) ApplyUp(ctx context.Context, ups []MigrationFile) error {
for _, up := range ups {
if _, err := m.Conn.ExecContext(ctx, up.SQL); err != nil {
expanded, err := expandPlaceholders(up.SQL, m.Schema)
if err != nil {
return fmt.Errorf("placeholder expansion up %d: %w", up.Version, err)
}
if _, err := m.Conn.ExecContext(ctx, expanded); err != nil {
return fmt.Errorf("apply up %d failed: %w", up.Version, err)
}
if err := m.Dialect.InsertVersion(ctx, m.Conn, up.Version); err != nil {
if err := m.Dialect.InsertVersion(ctx, m.Conn, m.Schema, up.Version); err != nil {
return err
}
}
Expand All @@ -41,12 +54,68 @@ func (m *Migrator) ApplyUp(ctx context.Context, ups []MigrationFile) error {
// ApplyDown applies downs.
func (m *Migrator) ApplyDown(ctx context.Context, downs []MigrationFile) error {
for _, down := range downs {
if _, err := m.Conn.ExecContext(ctx, down.SQL); err != nil {
expanded, err := expandPlaceholders(down.SQL, m.Schema)
if err != nil {
return fmt.Errorf("placeholder expansion down %d: %w", down.Version, err)
}
if _, err := m.Conn.ExecContext(ctx, expanded); err != nil {
return fmt.Errorf("apply down %d failed: %w", down.Version, err)
}
if err := m.Dialect.DeleteVersion(ctx, m.Conn, down.Version); err != nil {
if err := m.Dialect.DeleteVersion(ctx, m.Conn, m.Schema, down.Version); err != nil {
return err
}
}
return nil
}

const (
requiredSchemaToken = "{{schema}}"
optionalSchemaToken = "{{schema?}}"
)

// expandPlaceholders substitutes schema tokens in SQL.
// {{schema}} requires a non-empty schema; {{schema?}} inserts schema plus dot or nothing.
func expandPlaceholders(sql string, schema string) (string, error) {
// We process with a manual scan to support escaping via backslash: \{{schema}} or \{{schema?}} remain literal.
// Strategy: iterate runes, detect backslash before placeholder start; build output.
var b strings.Builder
// Precompute for efficiency.
req := requiredSchemaToken
opt := optionalSchemaToken
for i := 0; i < len(sql); {
// Handle escaped required placeholder
if i+1+len(req) <= len(sql) && sql[i] == '\\' && sql[i+1:i+1+len(req)] == req {
b.WriteString(req) // drop escape, keep literal token
i += 1 + len(req)
continue
}
// Handle escaped optional placeholder
if i+1+len(opt) <= len(sql) && sql[i] == '\\' && sql[i+1:i+1+len(opt)] == opt {
b.WriteString(opt)
i += 1 + len(opt)
continue
}
// Unescaped optional
if i+len(opt) <= len(sql) && sql[i:i+len(opt)] == opt {
if schema != "" {
b.WriteString(schema)
b.WriteByte('.')
}
i += len(opt)
continue
}
// Unescaped required
if i+len(req) <= len(sql) && sql[i:i+len(req)] == req {
if schema == "" {
return "", fmt.Errorf("%s used but schema not set", requiredSchemaToken)
}
b.WriteString(schema)
i += len(req)
continue
}
// Default: copy one byte
b.WriteByte(sql[i])
i++
}
return b.String(), nil
}
Loading