157 lines
3.7 KiB
Go
157 lines
3.7 KiB
Go
package store
|
|
|
|
import (
	"context"
	"database/sql"
	"embed"
	"fmt"
	"io/fs"
	"path"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"
)
|
|
|
|
// Migration is one schema migration read from a "<version>_<name>.sql" file.
type Migration struct {
	// Version is the integer value of the file name's numeric prefix
	// (e.g. 1 for "001_initial.sql").
	Version int

	// Name is the part of the file name between the version prefix and
	// ".sql" (e.g. "initial").
	Name string

	// SQL is the file's raw contents; it may contain several
	// semicolon-separated statements.
	SQL string
}

// migrationFilenameRE matches migration file names of the form
// "<digits>_<name>.sql", where name uses only lowercase ASCII letters,
// digits and underscores. Submatch 1 is the version, submatch 2 the name.
var migrationFilenameRE = regexp.MustCompile(`^(\d+)_([a-z0-9_]+)\.sql$`)

// LoadMigrations reads SQL migration files from directory dir of an embedded
// filesystem and returns them sorted by ascending version.
//
// Every regular file in dir must be named like "001_initial.sql" (see
// migrationFilenameRE); subdirectories are skipped. Versions must be unique
// and contiguous starting at 1. dir is a slash-separated io/fs path. It is
// cleaned first, so "." (or "") means the root of fsys, and forms such as
// "./migrations" or "migrations/" are accepted.
//
// An error is returned if dir cannot be read, a file name does not match,
// a file cannot be read, or the versions are duplicated or leave a gap.
func LoadMigrations(fsys embed.FS, dir string) ([]Migration, error) {
	return loadMigrationsFS(fsys, dir)
}

// loadMigrationsFS is LoadMigrations for any fs.FS (e.g. os.DirFS or
// fstest.MapFS), so the loading logic is not tied to embed.FS.
func loadMigrationsFS(fsys fs.FS, dir string) ([]Migration, error) {
	// io/fs names must be clean, slash-separated paths: "./x", "x/" and
	// "x//y" are all rejected. Clean dir once and build file paths with
	// path.Join (not filepath, whose separator is OS-specific).
	dir = path.Clean(dir)

	entries, err := fs.ReadDir(fsys, dir)
	if err != nil {
		return nil, fmt.Errorf("read migration dir: %w", err)
	}

	var migrations []Migration
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		match := migrationFilenameRE.FindStringSubmatch(e.Name())
		if match == nil {
			return nil, fmt.Errorf("invalid migration filename: %s", e.Name())
		}
		// The regex guarantees match[1] is all ASCII digits, so Atoi can
		// only fail on overflow.
		version, err := strconv.Atoi(match[1])
		if err != nil {
			return nil, fmt.Errorf("invalid version in %s: %w", e.Name(), err)
		}

		content, err := fs.ReadFile(fsys, path.Join(dir, e.Name()))
		if err != nil {
			return nil, fmt.Errorf("read %s: %w", e.Name(), err)
		}

		migrations = append(migrations, Migration{
			Version: version,
			Name:    match[2],
			SQL:     string(content),
		})
	}

	// Stable sort: entries with equal versions keep directory (file name)
	// order, which makes the duplicate error below deterministic.
	sort.SliceStable(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})

	// Versions must be unique and contiguous starting at 1. Duplicates get
	// their own error instead of being reported as a gap.
	for i, m := range migrations {
		if i > 0 && m.Version == migrations[i-1].Version {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s",
				m.Version, migrations[i-1].Name, m.Name)
		}
		if m.Version != i+1 {
			return nil, fmt.Errorf("migration version gap: expected %d, got %d (%s)",
				i+1, m.Version, m.Name)
		}
	}

	return migrations, nil
}
|
|
|
|
// Migrator applies migrations to a database.
|
|
type Migrator struct {
|
|
DB *sql.DB
|
|
Dialect Dialect
|
|
Migrations []Migration
|
|
|
|
// CreateTableSQL is the dialect-specific DDL for the schema_migrations
|
|
// table. Postgres and SQLite need slightly different syntax.
|
|
CreateTableSQL string
|
|
}
|
|
|
|
// Migrate applies all pending migrations.
|
|
func (m *Migrator) Migrate(ctx context.Context) error {
|
|
if _, err := m.DB.ExecContext(ctx, m.CreateTableSQL); err != nil {
|
|
return fmt.Errorf("create schema_migrations: %w", err)
|
|
}
|
|
|
|
applied, err := m.appliedVersions(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, mig := range m.Migrations {
|
|
if applied[mig.Version] {
|
|
continue
|
|
}
|
|
if err := m.apply(ctx, mig); err != nil {
|
|
return fmt.Errorf("apply migration %d (%s): %w", mig.Version, mig.Name, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) appliedVersions(ctx context.Context) (map[int]bool, error) {
|
|
rows, err := m.DB.QueryContext(ctx, `SELECT version FROM pase_schema_migrations`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read schema_migrations: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
applied := make(map[int]bool)
|
|
for rows.Next() {
|
|
var v int
|
|
if err := rows.Scan(&v); err != nil {
|
|
return nil, err
|
|
}
|
|
applied[v] = true
|
|
}
|
|
return applied, rows.Err()
|
|
}
|
|
|
|
func (m *Migrator) apply(ctx context.Context, mig Migration) error {
|
|
tx, err := m.DB.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Allow multi-statement migrations.
|
|
for _, stmt := range splitStatements(mig.SQL) {
|
|
if strings.TrimSpace(stmt) == "" {
|
|
continue
|
|
}
|
|
if _, err := tx.ExecContext(ctx, stmt); err != nil {
|
|
return fmt.Errorf("exec statement: %w\n--SQL--\n%s", err, stmt)
|
|
}
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx,
|
|
m.Dialect.Rebind(`INSERT INTO pase_schema_migrations (version, name, applied_at)
|
|
VALUES (?, ?, ?)`),
|
|
mig.Version, mig.Name, time.Now().UTC())
|
|
if err != nil {
|
|
return fmt.Errorf("record migration: %w", err)
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// splitStatements splits a migration script into individual statements at
// each top-level semicolon. The semicolons are dropped and the pieces are
// returned verbatim, including empty or whitespace-only ones (e.g. after a
// trailing ";"). For scripts with no quotes or comments the result is exactly
// strings.Split(sql, ";").
//
// Semicolons are ignored inside:
//   - single-quoted string literals ('a;b'; a doubled ” is an escaped quote),
//   - double-quoted and backtick-quoted identifiers,
//   - "--" line comments and "/* */" block comments (not nested),
//   - PostgreSQL dollar-quoted strings ($$...$$ or $tag$...$tag$), as used in
//     function bodies.
//
// It is still not a full SQL parser. Backslash escapes (Postgres E'\'') and
// constructs that need bare semicolons inside one statement, such as SQLite
// CREATE TRIGGER ... BEGIN ...; END, are split incorrectly. An unterminated
// quote or comment swallows the rest of the input. It is meant for migrations
// we control; don't use it on user-supplied SQL.
func splitStatements(sql string) []string {
	var stmts []string
	start := 0
	for i := 0; i < len(sql); {
		switch c := sql[i]; {
		case c == ';':
			stmts = append(stmts, sql[start:i])
			i++
			start = i
		case c == '\'' || c == '"' || c == '`':
			// Jump past the matching closing quote. A doubled quote ('it''s')
			// simply closes and immediately reopens the literal.
			i = skipSQLPast(sql, i+1, sql[i:i+1])
		case c == '-' && strings.HasPrefix(sql[i:], "--"):
			i = skipSQLPast(sql, i+2, "\n")
		case c == '/' && strings.HasPrefix(sql[i:], "/*"):
			i = skipSQLPast(sql, i+2, "*/")
		case c == '$':
			if tag := dollarQuoteTag(sql, i); tag != "" {
				i = skipSQLPast(sql, i+len(tag), tag)
			} else {
				i++ // a positional parameter like $1, or a lone '$'
			}
		default:
			i++
		}
	}
	return append(stmts, sql[start:])
}

// skipSQLPast returns the index just past the first occurrence of end in
// s[from:], or len(s) if end never occurs (unterminated quote or comment).
func skipSQLPast(s string, from int, end string) int {
	if j := strings.Index(s[from:], end); j >= 0 {
		return from + j + len(end)
	}
	return len(s)
}

// dollarQuoteTag returns the PostgreSQL dollar-quote opening tag ("$$" or
// "$name$") that starts at s[i], which must be '$'. It returns "" if s[i:]
// does not start one (for example a "$1" positional parameter).
func dollarQuoteTag(s string, i int) string {
	for j := i + 1; j < len(s); j++ {
		switch c := s[j]; {
		case c == '$':
			return s[i : j+1]
		case c == '_', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
			// valid tag character
		case '0' <= c && c <= '9' && j > i+1:
			// digits are allowed, but not as the first tag character
		default:
			return ""
		}
	}
	return ""
}
|