initial pase store definitions
This commit is contained in:
parent
fefe08f6f9
commit
b947535795
10 changed files with 723 additions and 0 deletions
157
store/migrate.go
Normal file
157
store/migrate.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
Version int
|
||||
Name string
|
||||
SQL string
|
||||
}
|
||||
|
||||
var migrationFilenameRE = regexp.MustCompile(`^(\d+)_([a-z0-9_]+)\.sql$`)
|
||||
|
||||
// LoadMigrations reads SQL files from an embedded FS.
|
||||
// Files must be named like "001_initial.sql".
|
||||
func LoadMigrations(fsys embed.FS, dir string) ([]Migration, error) {
|
||||
entries, err := fs.ReadDir(fsys, dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read migration dir: %w", err)
|
||||
}
|
||||
|
||||
var migrations []Migration
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
m := migrationFilenameRE.FindStringSubmatch(e.Name())
|
||||
if m == nil {
|
||||
return nil, fmt.Errorf("invalid migration filename: %s", e.Name())
|
||||
}
|
||||
version, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid version in %s: %w", e.Name(), err)
|
||||
}
|
||||
|
||||
content, err := fs.ReadFile(fsys, dir+"/"+e.Name())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s: %w", e.Name(), err)
|
||||
}
|
||||
|
||||
migrations = append(migrations, Migration{
|
||||
Version: version,
|
||||
Name: m[2],
|
||||
SQL: string(content),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
})
|
||||
|
||||
// Sanity check: versions must be unique and contiguous starting at 1.
|
||||
for i, m := range migrations {
|
||||
if m.Version != i+1 {
|
||||
return nil, fmt.Errorf("migration version gap: expected %d, got %d (%s)",
|
||||
i+1, m.Version, m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// Migrator applies migrations to a database.
|
||||
type Migrator struct {
|
||||
DB *sql.DB
|
||||
Dialect Dialect
|
||||
Migrations []Migration
|
||||
|
||||
// CreateTableSQL is the dialect-specific DDL for the schema_migrations
|
||||
// table. Postgres and SQLite need slightly different syntax.
|
||||
CreateTableSQL string
|
||||
}
|
||||
|
||||
// Migrate applies all pending migrations.
|
||||
func (m *Migrator) Migrate(ctx context.Context) error {
|
||||
if _, err := m.DB.ExecContext(ctx, m.CreateTableSQL); err != nil {
|
||||
return fmt.Errorf("create schema_migrations: %w", err)
|
||||
}
|
||||
|
||||
applied, err := m.appliedVersions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mig := range m.Migrations {
|
||||
if applied[mig.Version] {
|
||||
continue
|
||||
}
|
||||
if err := m.apply(ctx, mig); err != nil {
|
||||
return fmt.Errorf("apply migration %d (%s): %w", mig.Version, mig.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) appliedVersions(ctx context.Context) (map[int]bool, error) {
|
||||
rows, err := m.DB.QueryContext(ctx, `SELECT version FROM pase_schema_migrations`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read schema_migrations: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
applied := make(map[int]bool)
|
||||
for rows.Next() {
|
||||
var v int
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applied[v] = true
|
||||
}
|
||||
return applied, rows.Err()
|
||||
}
|
||||
|
||||
func (m *Migrator) apply(ctx context.Context, mig Migration) error {
|
||||
tx, err := m.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Allow multi-statement migrations.
|
||||
for _, stmt := range splitStatements(mig.SQL) {
|
||||
if strings.TrimSpace(stmt) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, stmt); err != nil {
|
||||
return fmt.Errorf("exec statement: %w\n--SQL--\n%s", err, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx,
|
||||
m.Dialect.Rebind(`INSERT INTO pase_schema_migrations (version, name, applied_at)
|
||||
VALUES (?, ?, ?)`),
|
||||
mig.Version, mig.Name, time.Now().UTC())
|
||||
if err != nil {
|
||||
return fmt.Errorf("record migration: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// splitStatements is a naive splitter on semicolons. It works for our
|
||||
// migrations because we control them. Don't use this on user-supplied SQL.
|
||||
func splitStatements(sql string) []string {
|
||||
return strings.Split(sql, ";")
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue