128 lines
4 KiB
Go
128 lines
4 KiB
Go
// Package sqlite provides a SQLite-backed implementation of store.Store.
|
|
// Tested with modernc.org/sqlite; other SQLite drivers may yield different results.
|
|
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.juancwu.dev/juancwu/pase/store"
|
|
)
|
|
|
|
// Store is a SQLite-backed implementation of store.Store.
|
|
//
|
|
// All queries are written with ? placeholders and rebound through the dialect
|
|
// at construction time. Every method is intended to be a single round-trip;
|
|
// methods that need atomicity (e.g. ConsumeToken) use a single statement or
|
|
// an explicit transaction — never read-modify-write through Go.
|
|
type Store struct {
|
|
db *sql.DB
|
|
dialect store.Dialect
|
|
|
|
// Pre-rebound query strings. One field per query keeps the SQL text out
|
|
// of method bodies and makes it easy to audit by reading the struct.
|
|
q store.Queries
|
|
}
|
|
|
|
// NewStore wraps an existing *sql.DB. It does not own the connection pool;
|
|
// the caller is responsible for opening, configuring, and closing it.
|
|
//
|
|
// Foreign key enforcement is required; configure it on the connection (e.g.
|
|
// via DSN `?_pragma=foreign_keys(1)` for modernc.org/sqlite, or
|
|
// `?_fk=1` for mattn/go-sqlite3) before passing the *sql.DB in.
|
|
func NewStore(db *sql.DB) *Store {
|
|
d := store.SQLiteDialect{}
|
|
s := &Store{db: db, dialect: d}
|
|
|
|
s.q = store.CanonicalQueries.Rebind(d)
|
|
|
|
return s
|
|
}
|
|
|
|
// CreateUser inserts a user row. Timestamps will be overwritten if they are set.
|
|
func (s *Store) CreateUser(ctx context.Context, u *store.User) error {
|
|
now := time.Now()
|
|
u.CreatedAt = now
|
|
u.UpdatedAt = now
|
|
|
|
_, err := s.db.ExecContext(ctx, s.q.CreateUser,
|
|
u.ID, u.Email, u.EmailVerifiedAt,
|
|
u.Username, u.UsernameNormalized,
|
|
u.DisplayName, u.ProfileImageURL,
|
|
u.Status, u.StatusReason,
|
|
u.StatusChangedAt, u.StatusExpiresAt,
|
|
u.FailedLoginCount, u.LastFailedLoginAt,
|
|
u.CreatedAt, u.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
errStr := err.Error()
|
|
if strings.Contains(errStr, "pase_users.email") {
|
|
return fmt.Errorf("pase/sqlite: create user: %w", store.ErrEmailAlreadyExists)
|
|
}
|
|
if strings.Contains(errStr, "pase_users.username_normalized") {
|
|
return fmt.Errorf("pase/sqlite: create user: %w", store.ErrUsernameAlreadyExists)
|
|
}
|
|
return fmt.Errorf("pase/sqlite: create user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetUserByID returns the user with the given id.
|
|
func (s *Store) GetUserByID(ctx context.Context, id string) (*store.User, error) {
|
|
row := s.db.QueryRowContext(ctx, s.q.GetUserByID, id)
|
|
|
|
u, err := s.scanUser(row)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, fmt.Errorf("pase/sqlite: get user by id: %w", store.ErrUserNotFound)
|
|
}
|
|
return nil, fmt.Errorf("pase/sqlite: get user by id: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// GetUserByEmail returns the user with the given email.
|
|
func (s *Store) GetUserByEmail(ctx context.Context, email string) (*store.User, error) {
|
|
row := s.db.QueryRowContext(ctx, s.q.GetUserByEmail, email)
|
|
u, err := s.scanUser(row)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, fmt.Errorf("pase/sqlite: get user by email: %w", store.ErrUserNotFound)
|
|
}
|
|
return nil, fmt.Errorf("pase/sqlite: get user by email: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// GetUserByUsername returns the user with the given username.
|
|
func (s *Store) GetUserByUsername(ctx context.Context, normalizedUsername string) (*store.User, error) {
|
|
row := s.db.QueryRowContext(ctx, s.q.GetUserByUsername, normalizedUsername)
|
|
u, err := s.scanUser(row)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, fmt.Errorf("pase/sqlite: get user by username: %w", store.ErrUserNotFound)
|
|
}
|
|
return nil, fmt.Errorf("pase/sqlite: get user by username: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
func (s *Store) scanUser(row *sql.Row) (*store.User, error) {
|
|
var u store.User
|
|
err := row.Scan(
|
|
&u.ID, &u.Email, &u.EmailVerifiedAt,
|
|
&u.Username, &u.UsernameNormalized, &u.DisplayName, &u.ProfileImageURL,
|
|
&u.Status, &u.StatusReason,
|
|
&u.StatusChangedAt, &u.StatusExpiresAt,
|
|
&u.FailedLoginCount, &u.LastFailedLoginAt,
|
|
&u.CreatedAt, &u.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|