diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..815523d --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.juancwu.dev/juancwu/pase + +go 1.26.2 + +require github.com/oklog/ulid/v2 v2.1.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4b2c1c3 --- /dev/null +++ b/go.sum @@ -0,0 +1,3 @@ +github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= +github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= diff --git a/store/dialect.go b/store/dialect.go new file mode 100644 index 0000000..116fb74 --- /dev/null +++ b/store/dialect.go @@ -0,0 +1,70 @@ +package store + +import ( + "strconv" + "strings" +) + +// Dialect captures the small syntactic differences between SQL backends. +// All queries are written with ? placeholders; Rebind rewrites them as +// needed for the target driver. Semantic differences (e.g., RETURNING +// support) are not the Dialect's concern — they belong in separate Store +// implementations. +type Dialect interface { + Rebind(query string) string +} + +// SQLiteDialect leaves queries unchanged; database/sql + the SQLite driver +// already accept ? placeholders. +type SQLiteDialect struct{} + +func (SQLiteDialect) Rebind(query string) string { return query } + +// PostgresDialect rewrites ? placeholders into $1, $2, ... while leaving +// any ? characters that appear inside single-quoted string literals or +// double-quoted identifiers untouched. +type PostgresDialect struct{} + +func (PostgresDialect) Rebind(query string) string { + var b strings.Builder + b.Grow(len(query) + 8) + + n := 0 + i := 0 + for i < len(query) { + c := query[i] + switch c { + case '\'', '"': + // Copy the entire quoted span verbatim, handling doubled-quote escapes. + quote := c + b.WriteByte(c) + i++ + for i < len(query) { + if query[i] == quote { + if i+1 < len(query) && query[i+1] == quote { + // Escaped quote: '' or "". Copy both bytes and continue. + b.WriteByte(quote) + b.WriteByte(quote) + i += 2 + continue + } + b.WriteByte(quote) + i++ + break + } + b.WriteByte(query[i]) + i++ + } + case '?': + n++ + b.WriteByte('$') + b.WriteString(strconv.Itoa(n)) + i++ + default: + b.WriteByte(c) + i++ + } + } + + return b.String() +} diff --git a/store/dialect_test.go b/store/dialect_test.go new file mode 100644 index 0000000..82eaf2f --- /dev/null +++ b/store/dialect_test.go @@ -0,0 +1,68 @@ +package store + +import "testing" + +func TestSQLiteDialect_Rebind_passthrough(t *testing.T) { + d := SQLiteDialect{} + in := `SELECT * FROM pase_users WHERE email = ? AND status = ?` + if got := d.Rebind(in); got != in { + t.Errorf("SQLite Rebind should be passthrough.\nin: %s\ngot: %s", in, got) + } +} + +func TestPostgresDialect_Rebind(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + { + name: "no placeholders", + in: `SELECT 1`, + want: `SELECT 1`, + }, + { + name: "single placeholder", + in: `SELECT * FROM pase_users WHERE id = ?`, + want: `SELECT * FROM pase_users WHERE id = $1`, + }, + { + name: "multiple placeholders", + in: `INSERT INTO pase_users (id, email, status) VALUES (?, ?, ?)`, + want: `INSERT INTO pase_users (id, email, status) VALUES ($1, $2, $3)`, + }, + { + name: "question mark inside single-quoted literal is preserved", + in: `SELECT * FROM t WHERE name = 'who?' AND id = ?`, + want: `SELECT * FROM t WHERE name = 'who?' AND id = $1`, + }, + { + name: "escaped single quote inside literal", + in: `SELECT * FROM t WHERE name = 'O''Reilly?' AND id = ?`, + want: `SELECT * FROM t WHERE name = 'O''Reilly?' AND id = $1`, + }, + { + name: "question mark inside double-quoted identifier is preserved", + in: `SELECT "weird?col" FROM t WHERE id = ?`, + want: `SELECT "weird?col" FROM t WHERE id = $1`, + }, + { + name: "escaped double quote inside identifier", + in: `SELECT "a""?b" FROM t WHERE id = ?`, + want: `SELECT "a""?b" FROM t WHERE id = $1`, + }, + { + name: "mix of literals and placeholders", + in: `UPDATE t SET name = 'foo?', other = ? WHERE id = ? AND tag = 'x?'`, + want: `UPDATE t SET name = 'foo?', other = $1 WHERE id = $2 AND tag = 'x?'`, + }, + } + d := PostgresDialect{} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := d.Rebind(tc.in); got != tc.want { + t.Errorf("Rebind mismatch\nin: %s\nwant: %s\ngot: %s", tc.in, tc.want, got) + } + }) + } +} diff --git a/store/jsonb.go b/store/jsonb.go new file mode 100644 index 0000000..0895387 --- /dev/null +++ b/store/jsonb.go @@ -0,0 +1,54 @@ +package store + +import ( + "database/sql/driver" + "fmt" +) + +// JSONB is a nullable JSON value backed by raw bytes. +// Scans from jsonb (Postgres) or TEXT (SQLite). Empty means SQL NULL. +type JSONB []byte + +// Scan implements sql.Scanner. +func (j *JSONB) Scan(src any) error { + if src == nil { + *j = nil + return nil + } + switch v := src.(type) { + case []byte: + // Copy: drivers may reuse the buffer between rows. + *j = append((*j)[:0], v...) + case string: + *j = []byte(v) + default: + return fmt.Errorf("pase: cannot scan %T into JSONB", src) + } + return nil +} + +// Value implements driver.Valuer. +func (j JSONB) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return []byte(j), nil +} + +// MarshalJSON makes JSONB transparent in API responses. +func (j JSONB) MarshalJSON() ([]byte, error) { + if len(j) == 0 { + return []byte("null"), nil + } + return []byte(j), nil +} + +// UnmarshalJSON lets you decode directly into JSONB. +func (j *JSONB) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *j = nil + return nil + } + *j = append((*j)[:0], data...) + return nil +} diff --git a/store/migrate.go b/store/migrate.go new file mode 100644 index 0000000..c937df6 --- /dev/null +++ b/store/migrate.go @@ -0,0 +1,157 @@ +package store + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "regexp" + "sort" + "strconv" + "strings" + "time" +) + +type Migration struct { + Version int + Name string + SQL string +} + +var migrationFilenameRE = regexp.MustCompile(`^(\d+)_([a-z0-9_]+)\.sql$`) + +// LoadMigrations reads SQL files from an embedded FS. +// Files must be named like "001_initial.sql". +func LoadMigrations(fsys embed.FS, dir string) ([]Migration, error) { + entries, err := fs.ReadDir(fsys, dir) + if err != nil { + return nil, fmt.Errorf("read migration dir: %w", err) + } + + var migrations []Migration + for _, e := range entries { + if e.IsDir() { + continue + } + m := migrationFilenameRE.FindStringSubmatch(e.Name()) + if m == nil { + return nil, fmt.Errorf("invalid migration filename: %s", e.Name()) + } + version, err := strconv.Atoi(m[1]) + if err != nil { + return nil, fmt.Errorf("invalid version in %s: %w", e.Name(), err) + } + + content, err := fs.ReadFile(fsys, dir+"/"+e.Name()) + if err != nil { + return nil, fmt.Errorf("read %s: %w", e.Name(), err) + } + + migrations = append(migrations, Migration{ + Version: version, + Name: m[2], + SQL: string(content), + }) + } + + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Version < migrations[j].Version + }) + + // Sanity check: versions must be unique and contiguous starting at 1. + for i, m := range migrations { + if m.Version != i+1 { + return nil, fmt.Errorf("migration version gap: expected %d, got %d (%s)", + i+1, m.Version, m.Name) + } + } + + return migrations, nil +} + +// Migrator applies migrations to a database. +type Migrator struct { + DB *sql.DB + Dialect Dialect + Migrations []Migration + + // CreateTableSQL is the dialect-specific DDL for the schema_migrations + // table. Postgres and SQLite need slightly different syntax. + CreateTableSQL string +} + +// Migrate applies all pending migrations. +func (m *Migrator) Migrate(ctx context.Context) error { + if _, err := m.DB.ExecContext(ctx, m.CreateTableSQL); err != nil { + return fmt.Errorf("create schema_migrations: %w", err) + } + + applied, err := m.appliedVersions(ctx) + if err != nil { + return err + } + + for _, mig := range m.Migrations { + if applied[mig.Version] { + continue + } + if err := m.apply(ctx, mig); err != nil { + return fmt.Errorf("apply migration %d (%s): %w", mig.Version, mig.Name, err) + } + } + + return nil +} + +func (m *Migrator) appliedVersions(ctx context.Context) (map[int]bool, error) { + rows, err := m.DB.QueryContext(ctx, `SELECT version FROM pase_schema_migrations`) + if err != nil { + return nil, fmt.Errorf("read schema_migrations: %w", err) + } + defer rows.Close() + + applied := make(map[int]bool) + for rows.Next() { + var v int + if err := rows.Scan(&v); err != nil { + return nil, err + } + applied[v] = true + } + return applied, rows.Err() +} + +func (m *Migrator) apply(ctx context.Context, mig Migration) error { + tx, err := m.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + // Allow multi-statement migrations. + for _, stmt := range splitStatements(mig.SQL) { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("exec statement: %w\n--SQL--\n%s", err, stmt) + } + } + + _, err = tx.ExecContext(ctx, + m.Dialect.Rebind(`INSERT INTO pase_schema_migrations (version, name, applied_at) + VALUES (?, ?, ?)`), + mig.Version, mig.Name, time.Now().UTC()) + if err != nil { + return fmt.Errorf("record migration: %w", err) + } + + return tx.Commit() +} + +// splitStatements is a naive splitter on semicolons. It works for our +// migrations because we control them. Don't use this on user-supplied SQL. +func splitStatements(sql string) []string { + return strings.Split(sql, ";") +} diff --git a/store/models.go b/store/models.go new file mode 100644 index 0000000..9888ccc --- /dev/null +++ b/store/models.go @@ -0,0 +1,245 @@ +package store + +import ( + "encoding/json" + "fmt" + "time" +) + +type UserStatus string + +const ( + StatusActive UserStatus = "active" + StatusDeactivated UserStatus = "deactivated" + StatusLocked UserStatus = "locked" + StatusBanned UserStatus = "banned" + StatusPendingDeletion UserStatus = "pending_deletion" +) + +type User struct { + ID string + Email string + EmailVerifiedAt NullTime + Username string + UsernameNormalized string + DisplayName string + ProfileImageURL string + + Status UserStatus + StatusReason string + StatusChangedAt NullTime + StatusExpiresAt NullTime + + FailedLoginCount int + LastFailedLoginAt NullTime + + CreatedAt time.Time + UpdatedAt time.Time +} + +type Permission struct { + ID string + Name string + Description NullString + CreatedAt time.Time + UpdatedAt time.Time +} + +type Role struct { + ID string + Name string + Description NullString + CreatedAt time.Time + UpdatedAt time.Time +} + +type RolePermission struct { + RoleID string + PermissionID string + CreatedAt time.Time +} + +type PermissionEffect string + +const ( + PermissionAllow PermissionEffect = "allow" + PermissionDeny PermissionEffect = "deny" +) + +type UserPermission struct { + UserID string + PermissionID string + Effect PermissionEffect + CreatedAt time.Time +} + +type Session struct { + IDHash string + UserID string + ExpiresAt time.Time + LastUsedAt time.Time + UserAgent NullString + IPAddress NullString + CreatedAt time.Time +} + +type TokenPurpose string + +const ( + TokenPurposeMagicLink TokenPurpose = "magic_link" + TokenPurposePasswordReset TokenPurpose = "password_reset" + TokenPurposeEmailVerify TokenPurpose = "email_verify" + TokenPurposeEmailChange TokenPurpose = "email_change" +) + +type Token struct { + ID string + UserID string + Purpose TokenPurpose + HashedValue string + Payload JSONB + ExpiresAt time.Time + ConsumedAt time.Time + CreatedAt time.Time +} + +type CredentialType string + +const ( + CredentialPassword CredentialType = "password" + CredentialPasskey CredentialType = "passkey" + CredentialTOTP CredentialType = "totp" + CredentialOAuth CredentialType = "oauth" +) + +type Credential struct { + ID string + UserID string + Type CredentialType + + // Used by passkeys (credential ID) and OAuth (provider account id). + // Null for password and TOTP. + Identifier NullString + + // Used by OAuth: "google", "github", etc. Null otherwise. + Provider NullString + + // The actual secret material. Format depends on type: + // password: argon2id hash string + // passkey: COSE public key (base64) + // totp: encrypted shared secret + // oauth: null (tokens go in `data`) + Secret NullString + + // Type-specific fields that don't fit elsewhere: + // passkey: { sign_count, transports, aaguid, backup_eligible } + // totp: { algorithm, digits, period } + // oauth: { access_token, refresh_token, expires_at, scope } + Data JSONB + + // Human-friendly label, useful for UI ("My iPhone", "YubiKey 5C"). + // Especially valuable for passkeys where users have multiple. + Name NullString + + LastUsedAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type PasskeyData struct { + SignCount uint32 `json:"sign_count"` + Transports []string `json:"transports"` + AAGUID string `json:"aaguid"` + BackupEligible bool `json:"backup_eligible"` +} + +type OAuthData struct { + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + Scope string `json:"scope,omitempty"` +} + +type TOTPData struct { + Algorithm string `json:"algorithm"` + Digits int `json:"digits"` + Period int `json:"period"` +} + +func (c *Credential) PasskeyData() (*PasskeyData, error) { + if c.Type != CredentialPasskey { + return nil, fmt.Errorf("pase: credential is %s, not passkey", c.Type) + } + if len(c.Data) == 0 { + return &PasskeyData{}, nil + } + var d PasskeyData + if err := json.Unmarshal(c.Data, &d); err != nil { + return nil, fmt.Errorf("pase: decode passkey data: %w", err) + } + return &d, nil +} + +func (c *Credential) SetPasskeyData(d *PasskeyData) error { + if c.Type != CredentialPasskey { + return fmt.Errorf("pase: credential is %s, not passkey", c.Type) + } + b, err := json.Marshal(d) + if err != nil { + return fmt.Errorf("pase: %w", err) + } + c.Data = b + return nil +} + +func (c *Credential) OAuthData() (*OAuthData, error) { + if c.Type != CredentialOAuth { + return nil, fmt.Errorf("pase: credential is %s, not oauth", c.Type) + } + if len(c.Data) == 0 { + return &OAuthData{}, nil + } + var d OAuthData + if err := json.Unmarshal(c.Data, &d); err != nil { + return nil, fmt.Errorf("pase: decode oauth data: %w", err) + } + return &d, nil +} + +func (c *Credential) SetOAuthData(d *OAuthData) error { + if c.Type != CredentialOAuth { + return fmt.Errorf("pase: credential is %s, not oauth", c.Type) + } + b, err := json.Marshal(d) + if err != nil { + return fmt.Errorf("pase: %w", err) + } + c.Data = b + return nil +} + +func (c *Credential) TOTPData() (*TOTPData, error) { + if c.Type != CredentialTOTP { + return nil, fmt.Errorf("pase: credential is %s, not totp", c.Type) + } + if len(c.Data) == 0 { + return &TOTPData{}, nil + } + var d TOTPData + if err := json.Unmarshal(c.Data, &d); err != nil { + return nil, fmt.Errorf("pase: decode totp data: %w", err) + } + return &d, nil +} + +func (c *Credential) SetTOTPData(d *TOTPData) error { + if c.Type != CredentialTOTP { + return fmt.Errorf("pase: credential is %s, not totp", c.Type) + } + b, err := json.Marshal(d) + if err != nil { + return fmt.Errorf("pase: %w", err) + } + c.Data = b + return nil +} diff --git a/store/nullables.go b/store/nullables.go new file mode 100644 index 0000000..fe9073a --- /dev/null +++ b/store/nullables.go @@ -0,0 +1,48 @@ +package store + +import ( + "database/sql" + "encoding/json" +) + +type NullTime sql.NullTime + +func (nt NullTime) MarshalJSON() ([]byte, error) { + if nt.Valid { + return json.Marshal(nt.Time) + } + return []byte("null"), nil +} + +func (nt *NullTime) UnmarshalJSON(data []byte) error { + if string(data) == "null" || string(data) == `""` { + nt.Valid = false + return nil + } + if err := json.Unmarshal(data, &nt.Time); err != nil { + return err + } + nt.Valid = true + return nil +} + +type NullString sql.NullString + +func (ns NullString) MarshalJSON() ([]byte, error) { + if ns.Valid { + return json.Marshal(ns.String) + } + return []byte("null"), nil +} + +func (ns *NullString) UnmarshalJSON(data []byte) error { + if string(data) == "null" || string(data) == `""` { + ns.Valid = false + return nil + } + if err := json.Unmarshal(data, &ns.String); err != nil { + return err + } + ns.Valid = true + return nil +} diff --git a/store/store.go b/store/store.go new file mode 100644 index 0000000..af69f47 --- /dev/null +++ b/store/store.go @@ -0,0 +1,52 @@ +package store + +import ( + "context" + "time" +) + +type Store interface { + CreateUser(ctx context.Context, u *User) error + GetUserByID(ctx context.Context, id string) (*User, error) + GetUserByEmail(ctx context.Context, email string) (*User, error) + GetUserByUsername(ctx context.Context, username string) (*User, error) + UpdateUser(ctx context.Context, u *User) error + + UpsertCredential(ctx context.Context, c *Credential) error + GetCredential(ctx context.Context, userID string, t CredentialType) (*Credential, error) + DeleteCredential(ctx context.Context, userID string, t CredentialType) error + + CreateSession(ctx context.Context, s *Session) error + GetSession(ctx context.Context, id string) (*Session, error) + DeleteSession(ctx context.Context, id string) error + DeleteUserSessions(ctx context.Context, userID string) error + + CreateToken(ctx context.Context, t *Token) error + ConsumeToken(ctx context.Context, hashedValue string, purpose TokenPurpose) (*Token, error) + DeleteExpiredTokens(ctx context.Context, before time.Time) error + + CreatePermission(ctx context.Context, p *Permission) error + GetPermissionByName(ctx context.Context, name string) (*Permission, error) + ListPermissions(ctx context.Context) ([]*Permission, error) + DeletePermission(ctx context.Context, id string) error + + CreateRole(ctx context.Context, r *Role) error + GetRoleByName(ctx context.Context, name string) (*Role, error) + ListRoles(ctx context.Context) ([]*Role, error) + DeleteRole(ctx context.Context, id string) error + + AssignPermissionToRole(ctx context.Context, roleID, permissionID string) error + RevokePermissionFromRole(ctx context.Context, roleID, permissionID string) error + ListRolePermissions(ctx context.Context, roleID string) ([]*Permission, error) + + AssignRoleToUser(ctx context.Context, userID, roleID string) error + RevokeRoleFromUser(ctx context.Context, userID, roleID string) error + ListUserRoles(ctx context.Context, userID string) ([]*Role, error) + + SetUserPermission(ctx context.Context, userID, permissionID string, effect PermissionEffect) error + DeleteUserPermission(ctx context.Context, userID, permissionID string) error + ListUserPermissions(ctx context.Context, userID string) ([]*UserPermission, error) + + UserHasPermissionViaRole(ctx context.Context, userID, permissionName string) (bool, error) + ListEffectivePermissions(ctx context.Context, userID string) ([]*Permission, error) +} diff --git a/store/ulid.go b/store/ulid.go new file mode 100644 index 0000000..8d54801 --- /dev/null +++ b/store/ulid.go @@ -0,0 +1,21 @@ +package store + +import "github.com/oklog/ulid/v2" + +// IDGenerator produces unique identifiers for entities. +type IDGenerator interface { + // NewID generates a new unique identifier for entities. + NewID() string +} + +// defaultIDGenerator generates ULIDs using crypto/rand entropy. +type defaultIDGenerator struct{} + +func (defaultIDGenerator) NewID() string { + return ulid.Make().String() +} + +// DefaultIDGenerator returns the standard ULID-based generator. +func DefaultIDGenerator() IDGenerator { + return defaultIDGenerator{} +}