pase/store/migrate.go
2026-05-04 18:39:25 +00:00

157 lines
3.7 KiB
Go

package store
import (
	"context"
	"database/sql"
	"embed"
	"fmt"
	"io/fs"
	"path"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"
)
// Migration is a single versioned schema change loaded from a SQL file.
type Migration struct {
	// Version is the numeric prefix of the file name (1 for "001_initial.sql").
	// LoadMigrations guarantees versions are unique and contiguous from 1.
	Version int
	// Name is the part of the file name between the version prefix and the
	// ".sql" extension ("initial" for "001_initial.sql").
	Name string
	// SQL is the raw file contents. It may contain several statements
	// separated by semicolons.
	SQL string
}

// migrationFilenameRE matches "<digits>_<name>.sql", where name consists of
// lowercase ASCII letters, digits and underscores. Group 1 captures the
// version digits, group 2 the name.
var migrationFilenameRE = regexp.MustCompile(`^(\d+)_([a-z0-9_]+)\.sql$`)

// LoadMigrations reads SQL files from directory dir of an embedded FS and
// returns them sorted by version.
//
// Every regular file in dir must be named like "001_initial.sql"; any other
// file name is an error. Subdirectories are skipped. Versions must be unique
// and contiguous starting at 1, so a missing or duplicated file is reported
// before anything touches a database. dir is a slash-separated fs path; "."
// means the root of fsys.
func LoadMigrations(fsys embed.FS, dir string) ([]Migration, error) {
	entries, err := fs.ReadDir(fsys, dir)
	if err != nil {
		return nil, fmt.Errorf("read migration dir: %w", err)
	}

	migrations := make([]Migration, 0, len(entries))
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		match := migrationFilenameRE.FindStringSubmatch(e.Name())
		if match == nil {
			return nil, fmt.Errorf("invalid migration filename: %s", e.Name())
		}
		version, err := strconv.Atoi(match[1])
		if err != nil {
			// \d only matches ASCII digits, so this fires only on overflow.
			return nil, fmt.Errorf("invalid version in %s: %w", e.Name(), err)
		}
		// fs paths are always slash-separated. path.Join (not string
		// concatenation) keeps the path valid when dir is ".": "./001_x.sql"
		// is rejected by fs.ValidPath and would read as not-exist.
		content, err := fs.ReadFile(fsys, path.Join(dir, e.Name()))
		if err != nil {
			return nil, fmt.Errorf("read %s: %w", e.Name(), err)
		}
		migrations = append(migrations, Migration{
			Version: version,
			Name:    match[2],
			SQL:     string(content),
		})
	}

	// Stable sort: fs.ReadDir yields entries in file-name order, so files that
	// share a version keep a deterministic order for the duplicate error.
	sort.SliceStable(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	// Sanity check: versions must be unique and contiguous starting at 1.
	// Duplicates are checked first so they are not misreported as a gap.
	for i, mig := range migrations {
		if i > 0 && mig.Version == migrations[i-1].Version {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s",
				mig.Version, migrations[i-1].Name, mig.Name)
		}
		if mig.Version != i+1 {
			return nil, fmt.Errorf("migration version gap: expected %d, got %d (%s)",
				i+1, mig.Version, mig.Name)
		}
	}
	return migrations, nil
}
// Migrator applies migrations to a database.
//
// Applied versions are tracked in the pase_schema_migrations table. Migrate
// creates it via CreateTableSQL, then reads its version column and inserts
// (version, name, applied_at) rows as migrations succeed.
type Migrator struct {
// DB is the database the migrations are applied to.
DB *sql.DB
// Dialect's Rebind method is applied to the bookkeeping INSERT, which is
// written with `?` placeholders. Presumably it converts them to the driver's
// bind syntax — TODO confirm against the Dialect definition.
Dialect Dialect
// Migrations are applied in slice order, skipping versions already recorded
// in pase_schema_migrations. LoadMigrations returns them sorted by version.
Migrations []Migration
// CreateTableSQL is the dialect-specific DDL for the pase_schema_migrations
// table. Postgres and SQLite need slightly different syntax. Migrate executes
// it on every call, so it must be safe to run repeatedly (e.g. CREATE TABLE
// IF NOT EXISTS).
CreateTableSQL string
}
// Migrate applies all pending migrations.
//
// It first executes CreateTableSQL so the bookkeeping table exists, then
// loads the set of recorded versions. It walks m.Migrations in slice order,
// applying each one whose version is not yet recorded. The first failure
// stops the run. Each migration is applied by apply in its own transaction,
// so migrations that already succeeded stay committed.
func (m *Migrator) Migrate(ctx context.Context) error {
	_, err := m.DB.ExecContext(ctx, m.CreateTableSQL)
	if err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	done, err := m.appliedVersions(ctx)
	if err != nil {
		return err
	}

	for i := range m.Migrations {
		pending := m.Migrations[i]
		if done[pending.Version] {
			continue // already recorded by an earlier run
		}
		if applyErr := m.apply(ctx, pending); applyErr != nil {
			return fmt.Errorf("apply migration %d (%s): %w", pending.Version, pending.Name, applyErr)
		}
	}
	return nil
}
// appliedVersions returns the set of migration versions recorded in
// pase_schema_migrations. A version is present (mapped to true) if a row for
// it exists. On any error the returned map is nil, never a partial set.
func (m *Migrator) appliedVersions(ctx context.Context) (map[int]bool, error) {
	rows, err := m.DB.QueryContext(ctx, `SELECT version FROM pase_schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("read schema_migrations: %w", err)
	}
	defer rows.Close()

	applied := make(map[int]bool)
	for rows.Next() {
		var v int
		if err := rows.Scan(&v); err != nil {
			return nil, fmt.Errorf("scan schema_migrations version: %w", err)
		}
		applied[v] = true
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read schema_migrations: %w", err)
	}
	return applied, nil
}
// apply executes one migration and records it in pase_schema_migrations,
// both inside a single transaction. The statements and the bookkeeping row
// therefore commit or roll back together, to the extent the database treats
// DDL transactionally (NOTE(review): confirm per dialect).
//
// Statements come from splitStatements(mig.SQL); whitespace-only fragments
// are skipped. A failing statement's error includes the statement text.
func (m *Migrator) apply(ctx context.Context, mig Migration) error {
	tx, err := m.DB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone. Its error is therefore deliberately ignored, and the
	// deferred call covers every early return below.
	defer tx.Rollback()

	// Allow multi-statement migrations.
	for _, stmt := range splitStatements(mig.SQL) {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		if _, err := tx.ExecContext(ctx, stmt); err != nil {
			return fmt.Errorf("exec statement: %w\n--SQL--\n%s", err, stmt)
		}
	}

	_, err = tx.ExecContext(ctx,
		m.Dialect.Rebind(`INSERT INTO pase_schema_migrations (version, name, applied_at)
VALUES (?, ?, ?)`),
		mig.Version, mig.Name, time.Now().UTC())
	if err != nil {
		return fmt.Errorf("record migration: %w", err)
	}

	// Wrap the commit error so a failure here is distinguishable from a
	// failing statement, since every prior step may already have succeeded.
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}
// splitStatements cuts sql at every semicolon and returns the pieces in
// order, without the semicolons. Empty and whitespace-only pieces are kept
// (callers skip them), so n semicolons always yield n+1 pieces.
//
// It is deliberately naive: a semicolon inside a string literal, quoted
// identifier, comment or trigger body is still treated as a separator. That
// works for our migrations because we control them. Don't use this on
// user-supplied SQL.
func splitStatements(sql string) []string {
	pieces := make([]string, 0, strings.Count(sql, ";")+1)
	rest := sql
	for {
		cut := strings.IndexByte(rest, ';')
		if cut < 0 {
			return append(pieces, rest)
		}
		pieces = append(pieces, rest[:cut])
		rest = rest[cut+1:]
	}
}