feat: drop sqlite support
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m27s
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m27s
This commit is contained in:
parent
5e00060421
commit
da718427bd
27 changed files with 1296 additions and 115 deletions
|
|
@ -5,8 +5,8 @@ APP_URL=http://127.0.0.1:7331 # required for base url in email links etc, port i
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=9000
|
PORT=9000
|
||||||
|
|
||||||
DB_DRIVER=sqlite
|
# PostgreSQL 17 connection string (libpq URL or DSN). Required.
|
||||||
DB_CONNECTION="./data/local.db?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)"
|
DB_CONNECTION="postgres://budgit:budgit@127.0.0.1:5432/budgit?sslmode=disable"
|
||||||
|
|
||||||
JWT_SECRET=
|
JWT_SECRET=
|
||||||
# Go duration format
|
# Go duration format
|
||||||
|
|
|
||||||
|
|
@ -64,11 +64,11 @@ Components live in `internal/ui/components/` — button, input, checkbox, dialog
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
App reads from `.env` file via `godotenv`. Key vars: `APP_ENV`, `APP_URL`, `DB_DRIVER` (pgx/sqlite), `DB_CONNECTION`, `JWT_SECRET`, `PORT`. See `internal/config/config.go` for all fields.
|
App reads from `.env` file via `godotenv`. Key vars: `APP_ENV`, `APP_URL`, `DB_CONNECTION` (libpq URL/DSN), `JWT_SECRET`, `PORT`. See `internal/config/config.go` for all fields.
|
||||||
|
|
||||||
## Database
|
## Database
|
||||||
|
|
||||||
PostgreSQL (pgx driver) or SQLite. Migrations auto-run on startup from `internal/db/migrations/` (Goose SQL format, embedded via `go:embed`). 8 migration files covering users, tokens, profiles, spaces, shopping lists, tags, expenses, invitations.
|
PostgreSQL 17 only (pgx driver). Migrations auto-run on startup from `internal/db/migrations/` (Goose SQL format, embedded via `go:embed`). Tests run against an ephemeral Postgres container via `task test`; set `BUDGIT_TEST_POSTGRES_URL` to point at a long-lived instance instead.
|
||||||
|
|
||||||
# templui Components
|
# templui Components
|
||||||
|
|
||||||
|
|
|
||||||
13
Taskfile.yml
13
Taskfile.yml
|
|
@ -26,18 +26,13 @@ tasks:
|
||||||
cmds:
|
cmds:
|
||||||
- echo "Starting app..."
|
- echo "Starting app..."
|
||||||
- task --parallel tailwind-watch templ
|
- task --parallel tailwind-watch templ
|
||||||
# Testing
|
# Testing — TestMain in each db-touching package auto-spins an ephemeral
|
||||||
|
# postgres:17-alpine container; honors BUDGIT_TEST_POSTGRES_URL when set
|
||||||
|
# (CI uses that path to point at a long-lived service).
|
||||||
test:
|
test:
|
||||||
desc: Run tests (SQLite only)
|
desc: Run tests (auto-starts an ephemeral PostgreSQL 17 container if needed)
|
||||||
cmds:
|
cmds:
|
||||||
- set -o pipefail && go test ./... -json | tparse -all
|
- set -o pipefail && go test ./... -json | tparse -all
|
||||||
test:integration:
|
|
||||||
desc: Run tests against both SQLite and PostgreSQL
|
|
||||||
cmds:
|
|
||||||
- docker run --name budgit-test-pg -d -p 5433:5432 -e POSTGRES_USER=budgit_test -e POSTGRES_PASSWORD=testpass -e POSTGRES_DB=budgit_test postgres:17-alpine
|
|
||||||
- defer: docker rm -f budgit-test-pg
|
|
||||||
- cmd: sleep 3
|
|
||||||
- cmd: BUDGIT_TEST_POSTGRES_URL="postgres://budgit_test:testpass@localhost:5433/budgit_test?sslmode=disable" set -o pipefail && go test ./... -json | tparse -all
|
|
||||||
# Production build
|
# Production build
|
||||||
build:
|
build:
|
||||||
desc: Build production binary
|
desc: Build production binary
|
||||||
|
|
|
||||||
4
go.mod
4
go.mod
|
|
@ -15,8 +15,8 @@ require (
|
||||||
github.com/pressly/goose/v3 v3.26.0
|
github.com/pressly/goose/v3 v3.26.0
|
||||||
github.com/shopspring/decimal v1.4.0
|
github.com/shopspring/decimal v1.4.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/templui/templui v1.9.5
|
||||||
github.com/wneessen/go-mail v0.7.2
|
github.com/wneessen/go-mail v0.7.2
|
||||||
modernc.org/sqlite v1.40.1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
@ -62,7 +62,6 @@ require (
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/segmentio/asm v1.2.0 // indirect
|
github.com/segmentio/asm v1.2.0 // indirect
|
||||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||||
github.com/templui/templui v1.9.5 // indirect
|
|
||||||
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d // indirect
|
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d // indirect
|
||||||
github.com/vertica/vertica-sql-go v1.3.3 // indirect
|
github.com/vertica/vertica-sql-go v1.3.3 // indirect
|
||||||
github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 // indirect
|
github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 // indirect
|
||||||
|
|
@ -87,6 +86,7 @@ require (
|
||||||
modernc.org/libc v1.66.10 // indirect
|
modernc.org/libc v1.66.10 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
modernc.org/sqlite v1.40.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
tool (
|
tool (
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,12 @@ type App struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg *config.Config) (*App, error) {
|
func New(cfg *config.Config) (*App, error) {
|
||||||
database, err := db.Init(cfg.DBDriver, cfg.DBConnection)
|
database, err := db.Init(cfg.DBConnection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RunMigrations(database.DB, cfg.DBDriver)
|
err = db.RunMigrations(database.DB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ type Config struct {
|
||||||
Host string
|
Host string
|
||||||
Port string
|
Port string
|
||||||
|
|
||||||
DBDriver string
|
|
||||||
DBConnection string
|
DBConnection string
|
||||||
|
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
|
|
@ -53,8 +52,7 @@ func Load(version string) *Config {
|
||||||
Host: envString("HOST", "127.0.0.1"),
|
Host: envString("HOST", "127.0.0.1"),
|
||||||
Port: envString("PORT", "9000"),
|
Port: envString("PORT", "9000"),
|
||||||
|
|
||||||
DBDriver: envString("DB_DRIVER", "sqlite"),
|
DBConnection: envRequired("DB_CONNECTION"),
|
||||||
DBConnection: envString("DB_CONNECTION", "./data/local.db?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)"),
|
|
||||||
|
|
||||||
JWTSecret: envRequired("JWT_SECRET"),
|
JWTSecret: envRequired("JWT_SECRET"),
|
||||||
JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default
|
JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default
|
||||||
|
|
|
||||||
|
|
@ -3,25 +3,16 @@ package db
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(driver, connection string) (*sqlx.DB, error) {
|
// Init opens a PostgreSQL connection pool. The connection string must be a
|
||||||
if driver == "sqlite" {
|
// libpq-style URL or DSN supported by the pgx stdlib driver.
|
||||||
dir := filepath.Dir(connection)
|
func Init(connection string) (*sqlx.DB, error) {
|
||||||
err := os.MkdirAll(dir, 0755)
|
db, err := sqlx.Connect("pgx", connection)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create data directory: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := sqlx.Connect(driver, connection)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -30,10 +21,9 @@ func Init(driver, connection string) (*sqlx.DB, error) {
|
||||||
db.SetMaxIdleConns(5)
|
db.SetMaxIdleConns(5)
|
||||||
db.SetConnMaxLifetime(5 * time.Minute)
|
db.SetConnMaxLifetime(5 * time.Minute)
|
||||||
|
|
||||||
slog.Info("database connected", "driver", driver)
|
slog.Info("database connected", "driver", "pgx")
|
||||||
|
|
||||||
err = db.Ping()
|
if err := db.Ping(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,22 +9,8 @@ import (
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dialectMap = map[string]string{
|
func setupGoose() error {
|
||||||
"sqlite": "sqlite3",
|
if err := goose.SetDialect("postgres"); err != nil {
|
||||||
"pgx": "postgres",
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDialect(driver string) string {
|
|
||||||
dialect, ok := dialectMap[driver]
|
|
||||||
if ok {
|
|
||||||
return dialect
|
|
||||||
}
|
|
||||||
return driver
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupGoose(driver string) error {
|
|
||||||
err := goose.SetDialect(getDialect(driver))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set dialect: %w", err)
|
return fmt.Errorf("failed to set dialect: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -34,22 +20,18 @@ func setupGoose(driver string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
goose.SetBaseFS(migrationsDir)
|
goose.SetBaseFS(migrationsDir)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunMigrations(db *sql.DB, driver string) error {
|
func RunMigrations(db *sql.DB) error {
|
||||||
err := setupGoose(driver)
|
if err := setupGoose(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = goose.Up(db, ".")
|
if err := goose.Up(db, "."); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to run migrations: %w", err)
|
return fmt.Errorf("failed to run migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("migrations completed successfully")
|
slog.Info("migrations completed successfully")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ CREATE TABLE space_audit_logs (
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL,
|
||||||
target_user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
|
target_user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
|
||||||
target_email TEXT,
|
target_email TEXT,
|
||||||
metadata JSONB NOT NULL DEFAULT '{}',
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ CREATE TABLE transaction_audit_logs (
|
||||||
transaction_id TEXT NOT NULL,
|
transaction_id TEXT NOT NULL,
|
||||||
actor_id TEXT REFERENCES users(id) ON DELETE SET NULL,
|
actor_id TEXT REFERENCES users(id) ON DELETE SET NULL,
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL,
|
||||||
metadata JSONB NOT NULL DEFAULT '{}',
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
-- +goose Up
|
||||||
|
-- +goose StatementBegin
|
||||||
|
-- The account-scoped activity feeds filter audit rows by metadata->>'account_id'.
|
||||||
|
-- A partial expression index is the right shape for this access pattern in
|
||||||
|
-- PostgreSQL 17: it is small (only the rows where the field exists), uses a
|
||||||
|
-- standard B-tree (cheap equality + ORDER BY created_at), and avoids the bloat
|
||||||
|
-- of a full GIN over the metadata document.
|
||||||
|
|
||||||
|
CREATE INDEX idx_space_audit_logs_account_id
|
||||||
|
ON space_audit_logs ((metadata->>'account_id'), created_at DESC)
|
||||||
|
WHERE action LIKE 'account.%';
|
||||||
|
|
||||||
|
CREATE INDEX idx_transaction_audit_logs_account_id
|
||||||
|
ON transaction_audit_logs ((metadata->>'account_id'), created_at DESC)
|
||||||
|
WHERE metadata ? 'account_id';
|
||||||
|
-- +goose StatementEnd
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
-- +goose StatementBegin
|
||||||
|
DROP INDEX IF EXISTS idx_space_audit_logs_account_id;
|
||||||
|
DROP INDEX IF EXISTS idx_transaction_audit_logs_account_id;
|
||||||
|
-- +goose StatementEnd
|
||||||
9
internal/handler/main_test.go
Normal file
9
internal/handler/main_test.go
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) { testutil.PostgresMain(m) }
|
||||||
9
internal/repository/main_test.go
Normal file
9
internal/repository/main_test.go
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) { testutil.PostgresMain(m) }
|
||||||
124
internal/repository/space_audit_log_test.go
Normal file
124
internal/repository/space_audit_log_test.go
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeSpaceAuditLog(t *testing.T, repo SpaceAuditLogRepository, spaceID string, action model.SpaceAuditAction, actorID *string, metadata map[string]any, ts time.Time) *model.SpaceAuditLog {
|
||||||
|
t.Helper()
|
||||||
|
var meta []byte
|
||||||
|
if metadata != nil {
|
||||||
|
var err error
|
||||||
|
meta, err = json.Marshal(metadata)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
entry := &model.SpaceAuditLog{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
SpaceID: spaceID,
|
||||||
|
ActorID: actorID,
|
||||||
|
Action: action,
|
||||||
|
Metadata: meta,
|
||||||
|
CreatedAt: ts,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Create(entry))
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogRepository_CreateAndList(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUserWithName(t, dbi.DB, "audit-actor@example.com", strPtr("Actor Name"))
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Audit Space")
|
||||||
|
|
||||||
|
base := time.Now().Add(-time.Hour)
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, map[string]any{"old_name": "A", "new_name": "B"}, base)
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionMemberInvited, &actor.ID, nil, base.Add(10*time.Minute))
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionDeleted, &actor.ID, map[string]any{"space_name": "Audit Space"}, base.Add(20*time.Minute))
|
||||||
|
|
||||||
|
count, err := repo.CountBySpace(space.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, count)
|
||||||
|
|
||||||
|
logs, err := repo.ListBySpace(space.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 3)
|
||||||
|
|
||||||
|
// Newest first.
|
||||||
|
assert.Equal(t, model.SpaceAuditActionDeleted, logs[0].Action)
|
||||||
|
assert.Equal(t, model.SpaceAuditActionMemberInvited, logs[1].Action)
|
||||||
|
assert.Equal(t, model.SpaceAuditActionRenamed, logs[2].Action)
|
||||||
|
|
||||||
|
// Actor join populated.
|
||||||
|
require.NotNil(t, logs[0].ActorName)
|
||||||
|
assert.Equal(t, "Actor Name", *logs[0].ActorName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogRepository_ListBySpace_Pagination(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUser(t, dbi.DB, "page@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Paged Space")
|
||||||
|
|
||||||
|
base := time.Now().Add(-time.Hour)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, nil, base.Add(time.Duration(i)*time.Minute))
|
||||||
|
}
|
||||||
|
|
||||||
|
page1, err := repo.ListBySpace(space.ID, 2, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, page1, 2)
|
||||||
|
|
||||||
|
page2, err := repo.ListBySpace(space.ID, 2, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, page2, 2)
|
||||||
|
|
||||||
|
// No overlap between pages.
|
||||||
|
assert.NotEqual(t, page1[0].ID, page2[0].ID)
|
||||||
|
assert.NotEqual(t, page1[1].ID, page2[0].ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogRepository_ListAccountEvents_FiltersByMetadata(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUser(t, dbi.DB, "acct-filter@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Filter Space")
|
||||||
|
acct1 := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Account 1")
|
||||||
|
acct2 := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Account 2")
|
||||||
|
|
||||||
|
base := time.Now().Add(-time.Hour)
|
||||||
|
// Account 1 events
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountCreated, &actor.ID, map[string]any{"account_id": acct1.ID, "account_name": "Account 1"}, base)
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountRenamed, &actor.ID, map[string]any{"account_id": acct1.ID, "old_name": "Account 1", "new_name": "Renamed"}, base.Add(time.Minute))
|
||||||
|
// Account 2 event
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountCreated, &actor.ID, map[string]any{"account_id": acct2.ID, "account_name": "Account 2"}, base.Add(2*time.Minute))
|
||||||
|
// Non-account event in same space — must NOT appear in account-scoped query
|
||||||
|
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, map[string]any{"old_name": "x", "new_name": "y"}, base.Add(3*time.Minute))
|
||||||
|
|
||||||
|
acct1Logs, err := repo.ListAccountEvents(acct1.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, acct1Logs, 2)
|
||||||
|
|
||||||
|
acct1Count, err := repo.CountAccountEvents(acct1.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, acct1Count)
|
||||||
|
|
||||||
|
acct2Count, err := repo.CountAccountEvents(acct2.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, acct2Count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func strPtr(s string) *string { return &s }
|
||||||
140
internal/repository/transaction_audit_log_test.go
Normal file
140
internal/repository/transaction_audit_log_test.go
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeTxAuditLog(t *testing.T, repo TransactionAuditLogRepository, transactionID string, action model.TransactionAuditAction, actorID *string, metadata map[string]any, ts time.Time) *model.TransactionAuditLog {
|
||||||
|
t.Helper()
|
||||||
|
var meta []byte
|
||||||
|
if metadata != nil {
|
||||||
|
var err error
|
||||||
|
meta, err = json.Marshal(metadata)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
entry := &model.TransactionAuditLog{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
TransactionID: transactionID,
|
||||||
|
ActorID: actorID,
|
||||||
|
Action: action,
|
||||||
|
Metadata: meta,
|
||||||
|
CreatedAt: ts,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Create(entry))
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionAuditLogRepository_CreateAndListByTransaction(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewTransactionAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUserWithName(t, dbi.DB, "tx-audit@example.com", strPtr("Tx Actor"))
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Tx Audit Space")
|
||||||
|
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
|
||||||
|
txn := testutil.CreateTestTransaction(t, dbi.DB, account.ID, "Coffee", model.TransactionTypeWithdrawal, decimal.NewFromInt(5))
|
||||||
|
|
||||||
|
base := time.Now().Add(-time.Hour)
|
||||||
|
writeTxAuditLog(t, repo, txn.ID, model.TransactionAuditActionCreated, &actor.ID,
|
||||||
|
map[string]any{"account_id": account.ID, "transaction_type": "withdrawal", "title": "Coffee", "amount": "5.00"}, base)
|
||||||
|
writeTxAuditLog(t, repo, txn.ID, model.TransactionAuditActionEdited, &actor.ID,
|
||||||
|
map[string]any{"account_id": account.ID, "changes": map[string]any{"title": map[string]any{"old": "Coffee", "new": "Latte"}}}, base.Add(time.Minute))
|
||||||
|
|
||||||
|
count, err := repo.CountByTransaction(txn.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
logs, err := repo.ListByTransaction(txn.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 2)
|
||||||
|
// Newest first.
|
||||||
|
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
|
||||||
|
require.NotNil(t, logs[0].ActorName)
|
||||||
|
assert.Equal(t, "Tx Actor", *logs[0].ActorName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionAuditLogRepository_ListByAccount_LiveAndDeletedFallback(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewTransactionAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUser(t, dbi.DB, "acct-list@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Acct List Space")
|
||||||
|
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
|
||||||
|
|
||||||
|
// Live transaction with audit entry
|
||||||
|
live := testutil.CreateTestTransaction(t, dbi.DB, account.ID, "Live", model.TransactionTypeDeposit, decimal.NewFromInt(10))
|
||||||
|
writeTxAuditLog(t, repo, live.ID, model.TransactionAuditActionCreated, &actor.ID,
|
||||||
|
map[string]any{"account_id": account.ID}, time.Now().Add(-2*time.Minute))
|
||||||
|
|
||||||
|
// Audit entry referencing a transaction that no longer exists.
|
||||||
|
// Resolution must fall back to metadata.account_id.
|
||||||
|
ghostID := uuid.NewString()
|
||||||
|
writeTxAuditLog(t, repo, ghostID, model.TransactionAuditActionDeleted, &actor.ID,
|
||||||
|
map[string]any{"account_id": account.ID, "title": "Ghost"}, time.Now().Add(-time.Minute))
|
||||||
|
|
||||||
|
// Audit entry for a different account — must not appear.
|
||||||
|
other := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Other")
|
||||||
|
otherTxn := testutil.CreateTestTransaction(t, dbi.DB, other.ID, "Other", model.TransactionTypeDeposit, decimal.NewFromInt(1))
|
||||||
|
writeTxAuditLog(t, repo, otherTxn.ID, model.TransactionAuditActionCreated, &actor.ID,
|
||||||
|
map[string]any{"account_id": other.ID}, time.Now())
|
||||||
|
|
||||||
|
count, err := repo.CountByAccount(account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count, "should count live + ghost-via-metadata")
|
||||||
|
|
||||||
|
logs, err := repo.ListByAccount(account.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 2)
|
||||||
|
// Confirm both kinds present (one live, one ghost).
|
||||||
|
ids := []string{logs[0].TransactionID, logs[1].TransactionID}
|
||||||
|
assert.Contains(t, ids, live.ID)
|
||||||
|
assert.Contains(t, ids, ghostID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionAuditLogRepository_ListBySpace(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
repo := NewTransactionAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
actor := testutil.CreateTestUser(t, dbi.DB, "space-list@example.com", nil)
|
||||||
|
|
||||||
|
// Two spaces, each with an account and a transaction.
|
||||||
|
spaceA := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Space A")
|
||||||
|
acctA := testutil.CreateTestAccount(t, dbi.DB, spaceA.ID, "Acct A")
|
||||||
|
txnA := testutil.CreateTestTransaction(t, dbi.DB, acctA.ID, "txnA", model.TransactionTypeDeposit, decimal.NewFromInt(1))
|
||||||
|
writeTxAuditLog(t, repo, txnA.ID, model.TransactionAuditActionCreated, &actor.ID,
|
||||||
|
map[string]any{"account_id": acctA.ID}, time.Now().Add(-time.Minute))
|
||||||
|
|
||||||
|
spaceB := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Space B")
|
||||||
|
acctB := testutil.CreateTestAccount(t, dbi.DB, spaceB.ID, "Acct B")
|
||||||
|
txnB := testutil.CreateTestTransaction(t, dbi.DB, acctB.ID, "txnB", model.TransactionTypeDeposit, decimal.NewFromInt(1))
|
||||||
|
writeTxAuditLog(t, repo, txnB.ID, model.TransactionAuditActionCreated, &actor.ID,
|
||||||
|
map[string]any{"account_id": acctB.ID}, time.Now())
|
||||||
|
|
||||||
|
// Ghost in space A (deleted txn).
|
||||||
|
ghostID := uuid.NewString()
|
||||||
|
writeTxAuditLog(t, repo, ghostID, model.TransactionAuditActionDeleted, &actor.ID,
|
||||||
|
map[string]any{"account_id": acctA.ID}, time.Now().Add(-30*time.Second))
|
||||||
|
|
||||||
|
countA, err := repo.CountBySpace(spaceA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, countA)
|
||||||
|
|
||||||
|
countB, err := repo.CountBySpace(spaceB.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, countB)
|
||||||
|
|
||||||
|
logsA, err := repo.ListBySpace(spaceA.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logsA, 2)
|
||||||
|
})
|
||||||
|
}
|
||||||
9
internal/routes/main_test.go
Normal file
9
internal/routes/main_test.go
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) { testutil.PostgresMain(m) }
|
||||||
|
|
@ -16,7 +16,6 @@ import (
|
||||||
|
|
||||||
func newTestApp(dbi testutil.DBInfo) *app.App {
|
func newTestApp(dbi testutil.DBInfo) *app.App {
|
||||||
cfg := testutil.TestConfig()
|
cfg := testutil.TestConfig()
|
||||||
cfg.DBDriver = dbi.Driver
|
|
||||||
|
|
||||||
userRepo := repository.NewUserRepository(dbi.DB)
|
userRepo := repository.NewUserRepository(dbi.DB)
|
||||||
tokenRepo := repository.NewTokenRepository(dbi.DB)
|
tokenRepo := repository.NewTokenRepository(dbi.DB)
|
||||||
|
|
|
||||||
213
internal/service/account_activity_test.go
Normal file
213
internal/service/account_activity_test.go
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubSpaceAuditRepo serves canned responses for the activity-merger tests so we can
|
||||||
|
// focus on the merge/sort/pagination logic without a real DB.
|
||||||
|
type stubSpaceAuditRepo struct {
|
||||||
|
listAccount []*model.SpaceAuditLogWithActor
|
||||||
|
countAccount int
|
||||||
|
listSpace []*model.SpaceAuditLogWithActor
|
||||||
|
countSpace int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubSpaceAuditRepo) Create(*model.SpaceAuditLog) error { return nil }
|
||||||
|
func (s *stubSpaceAuditRepo) ListBySpace(_ string, limit, _ int) ([]*model.SpaceAuditLogWithActor, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return firstN(s.listSpace, limit), nil
|
||||||
|
}
|
||||||
|
func (s *stubSpaceAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
|
||||||
|
func (s *stubSpaceAuditRepo) ListAccountEvents(_ string, limit, _ int) ([]*model.SpaceAuditLogWithActor, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return firstN(s.listAccount, limit), nil
|
||||||
|
}
|
||||||
|
func (s *stubSpaceAuditRepo) CountAccountEvents(string) (int, error) { return s.countAccount, s.err }
|
||||||
|
|
||||||
|
type stubTxAuditRepo struct {
|
||||||
|
listAccount []*model.TransactionAuditLogWithActor
|
||||||
|
countAccount int
|
||||||
|
listSpace []*model.TransactionAuditLogWithActor
|
||||||
|
countSpace int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubTxAuditRepo) Create(*model.TransactionAuditLog) error { return nil }
|
||||||
|
func (s *stubTxAuditRepo) ListByTransaction(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubTxAuditRepo) CountByTransaction(string) (int, error) { return 0, nil }
|
||||||
|
func (s *stubTxAuditRepo) ListByAccount(_ string, limit, _ int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return firstNTx(s.listAccount, limit), nil
|
||||||
|
}
|
||||||
|
func (s *stubTxAuditRepo) CountByAccount(string) (int, error) { return s.countAccount, s.err }
|
||||||
|
func (s *stubTxAuditRepo) ListBySpace(_ string, limit, _ int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return firstNTx(s.listSpace, limit), nil
|
||||||
|
}
|
||||||
|
func (s *stubTxAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
|
||||||
|
|
||||||
|
func firstN(s []*model.SpaceAuditLogWithActor, n int) []*model.SpaceAuditLogWithActor {
|
||||||
|
if n >= len(s) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:n]
|
||||||
|
}
|
||||||
|
func firstNTx(s []*model.TransactionAuditLogWithActor, n int) []*model.TransactionAuditLogWithActor {
|
||||||
|
if n >= len(s) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func spaceLog(action model.SpaceAuditAction, ts time.Time) *model.SpaceAuditLogWithActor {
|
||||||
|
return &model.SpaceAuditLogWithActor{
|
||||||
|
SpaceAuditLog: model.SpaceAuditLog{Action: action, CreatedAt: ts},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func txLog(action model.TransactionAuditAction, ts time.Time) *model.TransactionAuditLogWithActor {
|
||||||
|
return &model.TransactionAuditLogWithActor{
|
||||||
|
TransactionAuditLog: model.TransactionAuditLog{Action: action, CreatedAt: ts},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_List_MergesAndSortsByTimestamp(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
spaceRepo := &stubSpaceAuditRepo{
|
||||||
|
listAccount: []*model.SpaceAuditLogWithActor{
|
||||||
|
spaceLog(model.SpaceAuditActionAccountRenamed, now.Add(-1*time.Minute)),
|
||||||
|
spaceLog(model.SpaceAuditActionAccountCreated, now.Add(-10*time.Minute)),
|
||||||
|
},
|
||||||
|
countAccount: 2,
|
||||||
|
}
|
||||||
|
txRepo := &stubTxAuditRepo{
|
||||||
|
listAccount: []*model.TransactionAuditLogWithActor{
|
||||||
|
txLog(model.TransactionAuditActionEdited, now), // newest overall
|
||||||
|
txLog(model.TransactionAuditActionCreated, now.Add(-5*time.Minute)),
|
||||||
|
txLog(model.TransactionAuditActionDeleted, now.Add(-15*time.Minute)), // oldest overall
|
||||||
|
},
|
||||||
|
countAccount: 3,
|
||||||
|
}
|
||||||
|
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
|
||||||
|
|
||||||
|
rows, err := svc.List("acct-1", 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, rows, 5)
|
||||||
|
|
||||||
|
// Strictly descending by timestamp.
|
||||||
|
for i := 1; i < len(rows); i++ {
|
||||||
|
assert.False(t, rows[i].Timestamp().After(rows[i-1].Timestamp()),
|
||||||
|
"row %d (%v) is newer than row %d (%v)", i, rows[i].Timestamp(), i-1, rows[i-1].Timestamp())
|
||||||
|
}
|
||||||
|
// Top row is the transaction edit at `now`.
|
||||||
|
require.NotNil(t, rows[0].TxLog)
|
||||||
|
assert.Equal(t, model.TransactionAuditActionEdited, rows[0].TxLog.Action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_List_Pagination(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
spaceRepo := &stubSpaceAuditRepo{
|
||||||
|
listAccount: []*model.SpaceAuditLogWithActor{
|
||||||
|
spaceLog(model.SpaceAuditActionAccountCreated, now.Add(-30*time.Minute)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txRepo := &stubTxAuditRepo{
|
||||||
|
listAccount: []*model.TransactionAuditLogWithActor{
|
||||||
|
txLog(model.TransactionAuditActionEdited, now.Add(-10*time.Minute)),
|
||||||
|
txLog(model.TransactionAuditActionEdited, now.Add(-20*time.Minute)),
|
||||||
|
txLog(model.TransactionAuditActionEdited, now.Add(-40*time.Minute)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
|
||||||
|
|
||||||
|
page1, err := svc.List("a", 2, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, page1, 2)
|
||||||
|
|
||||||
|
page2, err := svc.List("a", 2, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, page2, 2)
|
||||||
|
|
||||||
|
// Total of 4 entries; page2[1] is the oldest.
|
||||||
|
assert.Equal(t, now.Add(-40*time.Minute).Unix(), page2[1].Timestamp().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_List_OffsetPastEndReturnsEmpty(t *testing.T) {
|
||||||
|
svc := NewAccountActivityService(
|
||||||
|
NewSpaceAuditLogService(&stubSpaceAuditRepo{}),
|
||||||
|
NewTransactionAuditLogService(&stubTxAuditRepo{}),
|
||||||
|
)
|
||||||
|
rows, err := svc.List("a", 10, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_Count_SumsBothSources(t *testing.T) {
|
||||||
|
svc := NewAccountActivityService(
|
||||||
|
NewSpaceAuditLogService(&stubSpaceAuditRepo{countAccount: 3}),
|
||||||
|
NewTransactionAuditLogService(&stubTxAuditRepo{countAccount: 7}),
|
||||||
|
)
|
||||||
|
count, err := svc.Count("a")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_List_PropagatesError(t *testing.T) {
|
||||||
|
svc := NewAccountActivityService(
|
||||||
|
NewSpaceAuditLogService(&stubSpaceAuditRepo{err: errors.New("boom")}),
|
||||||
|
NewTransactionAuditLogService(&stubTxAuditRepo{}),
|
||||||
|
)
|
||||||
|
_, err := svc.List("a", 10, 0)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_ListSpace_MergesSpaceAndTxFeeds(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
spaceRepo := &stubSpaceAuditRepo{
|
||||||
|
listSpace: []*model.SpaceAuditLogWithActor{
|
||||||
|
spaceLog(model.SpaceAuditActionRenamed, now.Add(-3*time.Minute)),
|
||||||
|
spaceLog(model.SpaceAuditActionMemberInvited, now.Add(-5*time.Minute)),
|
||||||
|
},
|
||||||
|
countSpace: 2,
|
||||||
|
}
|
||||||
|
txRepo := &stubTxAuditRepo{
|
||||||
|
listSpace: []*model.TransactionAuditLogWithActor{
|
||||||
|
txLog(model.TransactionAuditActionCreated, now),
|
||||||
|
txLog(model.TransactionAuditActionEdited, now.Add(-4*time.Minute)),
|
||||||
|
},
|
||||||
|
countSpace: 2,
|
||||||
|
}
|
||||||
|
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
|
||||||
|
|
||||||
|
rows, err := svc.ListSpace("space", 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, rows, 4)
|
||||||
|
require.NotNil(t, rows[0].TxLog, "newest is the tx-created row at `now`")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountActivityService_CountSpace_SumsBothSources(t *testing.T) {
|
||||||
|
svc := NewAccountActivityService(
|
||||||
|
NewSpaceAuditLogService(&stubSpaceAuditRepo{countSpace: 4}),
|
||||||
|
NewTransactionAuditLogService(&stubTxAuditRepo{countSpace: 6}),
|
||||||
|
)
|
||||||
|
count, err := svc.CountSpace("s")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, count)
|
||||||
|
}
|
||||||
112
internal/service/account_test.go
Normal file
112
internal/service/account_test.go
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/repository"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountService_CreateAccount_RecordsAudit(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
accountRepo := repository.NewAccountRepository(dbi.DB)
|
||||||
|
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
auditSvc := NewSpaceAuditLogService(auditRepo)
|
||||||
|
svc := NewAccountService(accountRepo)
|
||||||
|
svc.SetAuditLogger(auditSvc)
|
||||||
|
|
||||||
|
user := testutil.CreateTestUser(t, dbi.DB, "acct-create-audit@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||||
|
|
||||||
|
account, err := svc.CreateAccount(space.ID, "Checking", user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 1)
|
||||||
|
assert.Equal(t, model.SpaceAuditActionAccountCreated, logs[0].Action)
|
||||||
|
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, account.ID, meta["account_id"])
|
||||||
|
assert.Equal(t, "Checking", meta["account_name"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountService_RenameAccount_RecordsAuditOnlyWhenChanged(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
accountRepo := repository.NewAccountRepository(dbi.DB)
|
||||||
|
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
svc := NewAccountService(accountRepo)
|
||||||
|
svc.SetAuditLogger(NewSpaceAuditLogService(auditRepo))
|
||||||
|
|
||||||
|
user := testutil.CreateTestUser(t, dbi.DB, "acct-rename-audit@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||||
|
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Old")
|
||||||
|
|
||||||
|
// Rename to a new value records an audit row.
|
||||||
|
require.NoError(t, svc.RenameAccount(account.ID, "New", user.ID))
|
||||||
|
|
||||||
|
// Renaming to the same value does not.
|
||||||
|
require.NoError(t, svc.RenameAccount(account.ID, "New", user.ID))
|
||||||
|
|
||||||
|
count, err := auditRepo.CountAccountEvents(account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
|
||||||
|
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, "Old", meta["old_name"])
|
||||||
|
assert.Equal(t, "New", meta["new_name"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountService_DeleteAccount_RecordsAuditBeforeDelete(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
accountRepo := repository.NewAccountRepository(dbi.DB)
|
||||||
|
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
|
||||||
|
svc := NewAccountService(accountRepo)
|
||||||
|
svc.SetAuditLogger(NewSpaceAuditLogService(auditRepo))
|
||||||
|
|
||||||
|
user := testutil.CreateTestUser(t, dbi.DB, "acct-delete-audit@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||||
|
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Target")
|
||||||
|
|
||||||
|
require.NoError(t, svc.DeleteAccount(account.ID, user.ID))
|
||||||
|
|
||||||
|
// Account is gone.
|
||||||
|
_, err := accountRepo.ByID(account.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Audit row still exists and captured the pre-delete name (no FK on metadata).
|
||||||
|
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 1)
|
||||||
|
assert.Equal(t, model.SpaceAuditActionAccountDeleted, logs[0].Action)
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, "Target", meta["account_name"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountService_NoAuditLoggerSet_DoesNotPanic(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
// SetAuditLogger is intentionally optional so existing tests/callers that
|
||||||
|
// don't care about audit don't have to wire it.
|
||||||
|
svc := NewAccountService(repository.NewAccountRepository(dbi.DB))
|
||||||
|
|
||||||
|
user := testutil.CreateTestUser(t, dbi.DB, "no-audit@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||||
|
|
||||||
|
account, err := svc.CreateAccount(space.ID, "x", user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, svc.RenameAccount(account.ID, "y", user.ID))
|
||||||
|
require.NoError(t, svc.DeleteAccount(account.ID, user.ID))
|
||||||
|
})
|
||||||
|
}
|
||||||
9
internal/service/main_test.go
Normal file
9
internal/service/main_test.go
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) { testutil.PostgresMain(m) }
|
||||||
100
internal/service/space_audit_log_test.go
Normal file
100
internal/service/space_audit_log_test.go
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeSpaceAuditRepo struct {
|
||||||
|
created []*model.SpaceAuditLog
|
||||||
|
failNext error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSpaceAuditRepo) Create(log *model.SpaceAuditLog) error {
|
||||||
|
if f.failNext != nil {
|
||||||
|
err := f.failNext
|
||||||
|
f.failNext = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.created = append(f.created, log)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeSpaceAuditRepo) ListBySpace(string, int, int) ([]*model.SpaceAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSpaceAuditRepo) CountBySpace(string) (int, error) { return 0, nil }
|
||||||
|
func (f *fakeSpaceAuditRepo) ListAccountEvents(string, int, int) ([]*model.SpaceAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSpaceAuditRepo) CountAccountEvents(string) (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
func TestSpaceAuditLogService_Record_PersistsEntry(t *testing.T) {
|
||||||
|
repo := &fakeSpaceAuditRepo{}
|
||||||
|
svc := NewSpaceAuditLogService(repo)
|
||||||
|
|
||||||
|
svc.Record(RecordOptions{
|
||||||
|
SpaceID: "space-1",
|
||||||
|
ActorID: "actor-1",
|
||||||
|
Action: model.SpaceAuditActionRenamed,
|
||||||
|
TargetUserID: "target-1",
|
||||||
|
TargetEmail: "target@example.com",
|
||||||
|
Metadata: map[string]any{"old_name": "A", "new_name": "B"},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, repo.created, 1)
|
||||||
|
got := repo.created[0]
|
||||||
|
assert.Equal(t, "space-1", got.SpaceID)
|
||||||
|
require.NotNil(t, got.ActorID)
|
||||||
|
assert.Equal(t, "actor-1", *got.ActorID)
|
||||||
|
require.NotNil(t, got.TargetUserID)
|
||||||
|
assert.Equal(t, "target-1", *got.TargetUserID)
|
||||||
|
require.NotNil(t, got.TargetEmail)
|
||||||
|
assert.Equal(t, "target@example.com", *got.TargetEmail)
|
||||||
|
assert.Equal(t, model.SpaceAuditActionRenamed, got.Action)
|
||||||
|
assert.NotEmpty(t, got.ID)
|
||||||
|
assert.False(t, got.CreatedAt.IsZero())
|
||||||
|
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(got.Metadata, &meta))
|
||||||
|
assert.Equal(t, "A", meta["old_name"])
|
||||||
|
assert.Equal(t, "B", meta["new_name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogService_Record_OmitsBlankOptionalFields(t *testing.T) {
|
||||||
|
repo := &fakeSpaceAuditRepo{}
|
||||||
|
svc := NewSpaceAuditLogService(repo)
|
||||||
|
|
||||||
|
svc.Record(RecordOptions{
|
||||||
|
SpaceID: "space-1",
|
||||||
|
Action: model.SpaceAuditActionDeleted,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, repo.created, 1)
|
||||||
|
got := repo.created[0]
|
||||||
|
assert.Nil(t, got.ActorID)
|
||||||
|
assert.Nil(t, got.TargetUserID)
|
||||||
|
assert.Nil(t, got.TargetEmail)
|
||||||
|
assert.Empty(t, got.Metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogService_Record_SwallowsRepoError(t *testing.T) {
|
||||||
|
// Audit failures must not bubble up to break the user's action.
|
||||||
|
repo := &fakeSpaceAuditRepo{failNext: errors.New("boom")}
|
||||||
|
svc := NewSpaceAuditLogService(repo)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
svc.Record(RecordOptions{SpaceID: "s", Action: model.SpaceAuditActionRenamed})
|
||||||
|
})
|
||||||
|
assert.Empty(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpaceAuditLogService_Record_NilReceiverIsNoOp(t *testing.T) {
|
||||||
|
var svc *SpaceAuditLogService
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
svc.Record(RecordOptions{SpaceID: "s", Action: model.SpaceAuditActionRenamed})
|
||||||
|
})
|
||||||
|
}
|
||||||
76
internal/service/transaction_audit_log_test.go
Normal file
76
internal/service/transaction_audit_log_test.go
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTxAuditRepo struct {
|
||||||
|
created []*model.TransactionAuditLog
|
||||||
|
failNext error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTxAuditRepo) Create(log *model.TransactionAuditLog) error {
|
||||||
|
if f.failNext != nil {
|
||||||
|
err := f.failNext
|
||||||
|
f.failNext = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.created = append(f.created, log)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeTxAuditRepo) ListByTransaction(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeTxAuditRepo) CountByTransaction(string) (int, error) { return 0, nil }
|
||||||
|
func (f *fakeTxAuditRepo) ListByAccount(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeTxAuditRepo) CountByAccount(string) (int, error) { return 0, nil }
|
||||||
|
func (f *fakeTxAuditRepo) ListBySpace(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeTxAuditRepo) CountBySpace(string) (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
func TestTransactionAuditLogService_Record_PersistsEntry(t *testing.T) {
|
||||||
|
repo := &fakeTxAuditRepo{}
|
||||||
|
svc := NewTransactionAuditLogService(repo)
|
||||||
|
|
||||||
|
svc.Record(TransactionRecordOptions{
|
||||||
|
TransactionID: "txn-1",
|
||||||
|
ActorID: "actor-1",
|
||||||
|
Action: model.TransactionAuditActionEdited,
|
||||||
|
Metadata: map[string]any{"changes": map[string]any{"title": "x"}},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, repo.created, 1)
|
||||||
|
got := repo.created[0]
|
||||||
|
assert.Equal(t, "txn-1", got.TransactionID)
|
||||||
|
require.NotNil(t, got.ActorID)
|
||||||
|
assert.Equal(t, "actor-1", *got.ActorID)
|
||||||
|
assert.Equal(t, model.TransactionAuditActionEdited, got.Action)
|
||||||
|
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(got.Metadata, &meta))
|
||||||
|
assert.Contains(t, meta, "changes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionAuditLogService_Record_SwallowsRepoError(t *testing.T) {
|
||||||
|
repo := &fakeTxAuditRepo{failNext: errors.New("boom")}
|
||||||
|
svc := NewTransactionAuditLogService(repo)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
svc.Record(TransactionRecordOptions{TransactionID: "x", Action: model.TransactionAuditActionEdited})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionAuditLogService_Record_NilReceiverIsNoOp(t *testing.T) {
|
||||||
|
var svc *TransactionAuditLogService
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
svc.Record(TransactionRecordOptions{TransactionID: "x", Action: model.TransactionAuditActionEdited})
|
||||||
|
})
|
||||||
|
}
|
||||||
265
internal/service/transaction_test.go
Normal file
265
internal/service/transaction_test.go
Normal file
|
|
@ -0,0 +1,265 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/repository"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// txnFixture builds a fully wired TransactionService against a real DB along with
|
||||||
|
// the helper repos the tests need to inspect post-state.
|
||||||
|
type txnFixture struct {
|
||||||
|
svc *TransactionService
|
||||||
|
txAudit repository.TransactionAuditLogRepository
|
||||||
|
accounts repository.AccountRepository
|
||||||
|
user *model.User
|
||||||
|
account *model.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTxnFixture(t *testing.T, dbi testutil.DBInfo) *txnFixture {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
txnRepo := repository.NewTransactionRepository(dbi.DB)
|
||||||
|
categoryRepo := repository.NewCategoryRepository(dbi.DB)
|
||||||
|
accountRepo := repository.NewAccountRepository(dbi.DB)
|
||||||
|
auditRepo := repository.NewTransactionAuditLogRepository(dbi.DB)
|
||||||
|
|
||||||
|
accountSvc := NewAccountService(accountRepo)
|
||||||
|
auditSvc := NewTransactionAuditLogService(auditRepo)
|
||||||
|
svc := NewTransactionService(txnRepo, categoryRepo, accountSvc)
|
||||||
|
svc.SetAuditLogger(auditSvc)
|
||||||
|
|
||||||
|
user := testutil.CreateTestUser(t, dbi.DB, t.Name()+"@example.com", nil)
|
||||||
|
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||||
|
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
|
||||||
|
|
||||||
|
return &txnFixture{
|
||||||
|
svc: svc,
|
||||||
|
txAudit: auditRepo,
|
||||||
|
accounts: accountRepo,
|
||||||
|
user: user,
|
||||||
|
account: account,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_Deposit_RecordsAuditAndUpdatesBalance(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
txn, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID,
|
||||||
|
Title: "Paycheck",
|
||||||
|
Amount: decimal.NewFromInt(100),
|
||||||
|
OccurredAt: time.Now(),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, decimal.NewFromInt(100).Equal(txn.Value))
|
||||||
|
|
||||||
|
updated, err := f.accounts.ByID(f.account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, decimal.NewFromInt(100).Equal(updated.Balance))
|
||||||
|
|
||||||
|
logs, err := f.txAudit.ListByTransaction(txn.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 1)
|
||||||
|
assert.Equal(t, model.TransactionAuditActionCreated, logs[0].Action)
|
||||||
|
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, "deposit", meta["transaction_type"])
|
||||||
|
assert.Equal(t, f.account.ID, meta["account_id"])
|
||||||
|
assert.Equal(t, "Paycheck", meta["title"])
|
||||||
|
assert.Equal(t, "100.00", meta["amount"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_PayBill_RecordsAuditAndDebitsBalance(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
// Seed some balance via deposit.
|
||||||
|
_, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(50), OccurredAt: time.Now(), ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
txn, err := f.svc.PayBill(PayBillInput{
|
||||||
|
AccountID: f.account.ID,
|
||||||
|
Title: "Rent",
|
||||||
|
Amount: decimal.NewFromInt(20),
|
||||||
|
OccurredAt: time.Now(),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := f.accounts.ByID(f.account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, decimal.NewFromInt(30).Equal(updated.Balance))
|
||||||
|
|
||||||
|
logs, err := f.txAudit.ListByTransaction(txn.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 1)
|
||||||
|
var meta map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, "withdrawal", meta["transaction_type"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_UpdateDeposit_RebalancesAndDiffs(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
original, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID, Title: "Old", Amount: decimal.NewFromInt(40),
|
||||||
|
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
|
||||||
|
TransactionID: original.ID,
|
||||||
|
Title: "New",
|
||||||
|
Amount: decimal.NewFromInt(60), // +20 net
|
||||||
|
OccurredAt: time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Balance reflects the swap (40 → 60 means +20 from 40 baseline).
|
||||||
|
updated, err := f.accounts.ByID(f.account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, decimal.NewFromInt(60).Equal(updated.Balance))
|
||||||
|
|
||||||
|
// 2 audit rows: created + edited (newest first).
|
||||||
|
logs, err := f.txAudit.ListByTransaction(original.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 2)
|
||||||
|
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
|
||||||
|
|
||||||
|
var meta struct {
|
||||||
|
AccountID string `json:"account_id"`
|
||||||
|
Changes map[string]map[string]any `json:"changes"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
|
||||||
|
assert.Equal(t, f.account.ID, meta.AccountID)
|
||||||
|
assert.Contains(t, meta.Changes, "title")
|
||||||
|
assert.Equal(t, "Old", meta.Changes["title"]["old"])
|
||||||
|
assert.Equal(t, "New", meta.Changes["title"]["new"])
|
||||||
|
assert.Contains(t, meta.Changes, "amount")
|
||||||
|
assert.Equal(t, "40.00", meta.Changes["amount"]["old"])
|
||||||
|
assert.Equal(t, "60.00", meta.Changes["amount"]["new"])
|
||||||
|
assert.Contains(t, meta.Changes, "occurred_at")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_UpdateDeposit_NoChanges_NoAudit(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
original, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID, Title: "Same", Amount: decimal.NewFromInt(10),
|
||||||
|
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Update with identical values.
|
||||||
|
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
|
||||||
|
TransactionID: original.ID,
|
||||||
|
Title: "Same",
|
||||||
|
Amount: decimal.NewFromInt(10),
|
||||||
|
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Only the original `created` audit row exists; no `edited` row.
|
||||||
|
count, err := f.txAudit.CountByTransaction(original.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_UpdateBill_RebalancesAndDiffs(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
// Seed funds and then a bill.
|
||||||
|
_, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(100), OccurredAt: time.Now(), ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
bill, err := f.svc.PayBill(PayBillInput{
|
||||||
|
AccountID: f.account.ID, Title: "Cable", Amount: decimal.NewFromInt(30), OccurredAt: time.Now(), ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = f.svc.UpdateBill(UpdateBillInput{
|
||||||
|
TransactionID: bill.ID,
|
||||||
|
Title: "Internet",
|
||||||
|
Amount: decimal.NewFromInt(40), // -10 vs original
|
||||||
|
OccurredAt: bill.OccurredAt,
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 100 - 40 = 60.
|
||||||
|
updated, err := f.accounts.ByID(f.account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, decimal.NewFromInt(60).Equal(updated.Balance))
|
||||||
|
|
||||||
|
logs, err := f.txAudit.ListByTransaction(bill.ID, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, logs, 2)
|
||||||
|
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_UpdateDeposit_RejectsBillTransaction(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
_, err := f.svc.Deposit(DepositInput{
|
||||||
|
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(50), OccurredAt: time.Now(), ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
bill, err := f.svc.PayBill(PayBillInput{
|
||||||
|
AccountID: f.account.ID, Title: "Bill", Amount: decimal.NewFromInt(10), OccurredAt: time.Now(), ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
|
||||||
|
TransactionID: bill.ID,
|
||||||
|
Title: "x",
|
||||||
|
Amount: decimal.NewFromInt(1),
|
||||||
|
OccurredAt: time.Now(),
|
||||||
|
ActorID: f.user.ID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionService_Validations(t *testing.T) {
|
||||||
|
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||||
|
f := newTxnFixture(t, dbi)
|
||||||
|
|
||||||
|
_, err := f.svc.Deposit(DepositInput{AccountID: f.account.ID, Amount: decimal.NewFromInt(1), OccurredAt: time.Now()})
|
||||||
|
assert.Error(t, err, "blank title")
|
||||||
|
|
||||||
|
_, err = f.svc.Deposit(DepositInput{AccountID: f.account.ID, Title: "x", Amount: decimal.NewFromInt(0), OccurredAt: time.Now()})
|
||||||
|
assert.Error(t, err, "zero amount")
|
||||||
|
|
||||||
|
_, err = f.svc.Deposit(DepositInput{AccountID: f.account.ID, Title: "x", Amount: decimal.NewFromInt(1)})
|
||||||
|
assert.Error(t, err, "missing date")
|
||||||
|
|
||||||
|
_, err = f.svc.PayBill(PayBillInput{Title: "x", Amount: decimal.NewFromInt(1), OccurredAt: time.Now()})
|
||||||
|
assert.Error(t, err, "missing account id")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -23,8 +23,7 @@ func TestConfig() *config.Config {
|
||||||
AppURL: "http://localhost:9999",
|
AppURL: "http://localhost:9999",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: "9999",
|
Port: "9999",
|
||||||
DBDriver: "sqlite",
|
DBConnection: "",
|
||||||
DBConnection: ":memory:",
|
|
||||||
JWTSecret: "test-secret-key-for-testing-only",
|
JWTSecret: "test-secret-key-for-testing-only",
|
||||||
JWTExpiry: 24 * time.Hour,
|
JWTExpiry: 24 * time.Hour,
|
||||||
TokenMagicLinkExpiry: 10 * time.Minute,
|
TokenMagicLinkExpiry: 10 * time.Minute,
|
||||||
|
|
|
||||||
118
internal/testutil/postgres_main.go
Normal file
118
internal/testutil/postgres_main.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgresMain is the TestMain entry point used by every package whose tests touch
|
||||||
|
// the database. It guarantees a running PostgreSQL 17 instance for the duration of
|
||||||
|
// the test binary:
|
||||||
|
//
|
||||||
|
// - If BUDGIT_TEST_POSTGRES_URL is already set, it is used as-is. CI and `task test`
|
||||||
|
// hit this path.
|
||||||
|
// - Otherwise an ephemeral `postgres:17-alpine` container is started on a free
|
||||||
|
// local port, BUDGIT_TEST_POSTGRES_URL is exported to it for the test process,
|
||||||
|
// and the container is removed when the test binary exits — even on panic, via
|
||||||
|
// a deferred cleanup around m.Run().
|
||||||
|
//
|
||||||
|
// Usage in each test package:
|
||||||
|
//
|
||||||
|
// func TestMain(m *testing.M) { testutil.PostgresMain(m) }
|
||||||
|
func PostgresMain(m *testing.M) {
|
||||||
|
if os.Getenv("BUDGIT_TEST_POSTGRES_URL") != "" {
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("docker"); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "testutil.PostgresMain: BUDGIT_TEST_POSTGRES_URL is unset and `docker` is not on PATH; cannot run db tests")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := freePort()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: failed to find free port: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
containerName := fmt.Sprintf("budgit-test-pg-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||||
|
|
||||||
|
startCmd := exec.Command("docker", "run", "--rm", "-d",
|
||||||
|
"--name", containerName,
|
||||||
|
"-p", fmt.Sprintf("%d:5432", port),
|
||||||
|
"-e", "POSTGRES_USER=budgit_test",
|
||||||
|
"-e", "POSTGRES_PASSWORD=testpass",
|
||||||
|
"-e", "POSTGRES_DB=budgit_test",
|
||||||
|
// tmpfs for the data dir keeps tests fast — we don't care about durability.
|
||||||
|
"--tmpfs", "/var/lib/postgresql/data:rw",
|
||||||
|
"postgres:17-alpine",
|
||||||
|
)
|
||||||
|
if out, err := startCmd.CombinedOutput(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: docker run failed: %v\n%s\n", err, out)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
stop := func() {
|
||||||
|
// `docker rm -f` because --rm only fires on a clean exit; force-stop the
|
||||||
|
// container regardless of state so leftover containers don't accumulate.
|
||||||
|
_ = exec.Command("docker", "rm", "-f", containerName).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("postgres://budgit_test:testpass@127.0.0.1:%d/budgit_test?sslmode=disable", port)
|
||||||
|
if err := waitForPostgres(url, 60*time.Second); err != nil {
|
||||||
|
stop()
|
||||||
|
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: postgres did not become ready: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Setenv("BUDGIT_TEST_POSTGRES_URL", url); err != nil {
|
||||||
|
stop()
|
||||||
|
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: setenv failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run tests, then ALWAYS stop the container — including on panic.
|
||||||
|
code := func() int {
|
||||||
|
defer stop()
|
||||||
|
return m.Run()
|
||||||
|
}()
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func freePort() (int, error) {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
return l.Addr().(*net.TCPAddr).Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForPostgres polls until a real client connection succeeds. pg_isready isn't
|
||||||
|
// sufficient under load — under parallel `go test ./...` we've seen it report ready
|
||||||
|
// while client connections still fail with "unexpected EOF" because the server is
|
||||||
|
// still finishing startup.
|
||||||
|
func waitForPostgres(url string, timeout time.Duration) error {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
var lastErr error
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
db, err := sql.Open("pgx", url)
|
||||||
|
if err == nil {
|
||||||
|
if err = db.Ping(); err == nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_ = db.Close()
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("timed out after %s: %w", timeout, lastErr)
|
||||||
|
}
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"git.juancwu.dev/juancwu/budgit/internal/model"
|
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateTestUser inserts a user directly into the database.
|
// CreateTestUser inserts a user directly into the database.
|
||||||
|
|
@ -79,6 +80,52 @@ func CreateTestSpace(t *testing.T, db *sqlx.DB, ownerID, name string) *model.Spa
|
||||||
return space
|
return space
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateTestAccount inserts an account directly into the database.
|
||||||
|
func CreateTestAccount(t *testing.T, db *sqlx.DB, spaceID, name string) *model.Account {
|
||||||
|
t.Helper()
|
||||||
|
now := time.Now()
|
||||||
|
account := &model.Account{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
Name: name,
|
||||||
|
SpaceID: spaceID,
|
||||||
|
Balance: decimal.Zero,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO accounts (id, name, space_id, balance, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||||
|
account.ID, account.Name, account.SpaceID, account.Balance, account.CreatedAt, account.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTestAccount: %v", err)
|
||||||
|
}
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTestTransaction inserts a transaction directly into the database.
|
||||||
|
func CreateTestTransaction(t *testing.T, db *sqlx.DB, accountID, title string, txnType model.TransactionType, amount decimal.Decimal) *model.Transaction {
|
||||||
|
t.Helper()
|
||||||
|
now := time.Now()
|
||||||
|
txn := &model.Transaction{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
Value: amount,
|
||||||
|
Type: txnType,
|
||||||
|
AccountID: accountID,
|
||||||
|
Title: title,
|
||||||
|
OccurredAt: now,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO transactions (id, value, type, account_id, title, description, occurred_at, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||||
|
txn.ID, txn.Value, txn.Type, txn.AccountID, txn.Title, txn.Description, txn.OccurredAt, txn.CreatedAt, txn.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTestTransaction: %v", err)
|
||||||
|
}
|
||||||
|
return txn
|
||||||
|
}
|
||||||
|
|
||||||
// CreateTestToken inserts a token directly into the database.
|
// CreateTestToken inserts a token directly into the database.
|
||||||
func CreateTestToken(t *testing.T, db *sqlx.DB, userID, tokenType, tokenString string, expiresAt time.Time) *model.Token {
|
func CreateTestToken(t *testing.T, db *sqlx.DB, userID, tokenType, tokenString string, expiresAt time.Time) *model.Token {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
|
||||||
|
|
@ -10,26 +10,24 @@ import (
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DBInfo holds a test database connection and its driver name.
|
// DBInfo holds a test database connection.
|
||||||
type DBInfo struct {
|
type DBInfo struct {
|
||||||
DB *sqlx.DB
|
DB *sqlx.DB
|
||||||
Driver string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForEachDB runs the given test function against both SQLite and PostgreSQL.
|
// ForEachDB runs the test function against PostgreSQL. Skips when
|
||||||
// PostgreSQL tests are skipped when BUDGIT_TEST_POSTGRES_URL is unset.
|
// BUDGIT_TEST_POSTGRES_URL is not set so quick local runs don't fail
|
||||||
|
// without a database. CI must always set it.
|
||||||
|
//
|
||||||
|
// Each test gets its own schema for isolation; the schema is dropped on
|
||||||
|
// cleanup. The function name is preserved for backwards compatibility,
|
||||||
|
// although there is now only one engine.
|
||||||
func ForEachDB(t *testing.T, fn func(t *testing.T, dbi DBInfo)) {
|
func ForEachDB(t *testing.T, fn func(t *testing.T, dbi DBInfo)) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
t.Run("sqlite", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
dbi := newSQLiteDB(t)
|
|
||||||
fn(t, dbi)
|
|
||||||
})
|
|
||||||
|
|
||||||
pgURL := os.Getenv("BUDGIT_TEST_POSTGRES_URL")
|
pgURL := os.Getenv("BUDGIT_TEST_POSTGRES_URL")
|
||||||
if pgURL == "" {
|
if pgURL == "" {
|
||||||
t.Log("skipping postgres tests: BUDGIT_TEST_POSTGRES_URL not set")
|
t.Skip("skipping db tests: BUDGIT_TEST_POSTGRES_URL not set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -40,55 +38,25 @@ func ForEachDB(t *testing.T, fn func(t *testing.T, dbi DBInfo)) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSQLiteDB(t *testing.T) DBInfo {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
// Use a unique in-memory database per test via a unique DSN.
|
|
||||||
// Each file::memory:?cache=shared&name=X uses a separate in-memory DB.
|
|
||||||
safeName := strings.ReplaceAll(t.Name(), "/", "_")
|
|
||||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_pragma=foreign_keys(1)", safeName)
|
|
||||||
|
|
||||||
sqliteDB, err := sqlx.Connect("sqlite", dsn)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to connect to sqlite: %v", err)
|
|
||||||
}
|
|
||||||
// SQLite in-memory DBs are destroyed when the last connection closes.
|
|
||||||
// Keep at least one open so it survives the test.
|
|
||||||
sqliteDB.SetMaxOpenConns(1)
|
|
||||||
|
|
||||||
t.Cleanup(func() { sqliteDB.Close() })
|
|
||||||
|
|
||||||
err = db.RunMigrations(sqliteDB.DB, "sqlite")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to run sqlite migrations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return DBInfo{DB: sqliteDB, Driver: "sqlite"}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPostgresDB(t *testing.T, baseURL string) DBInfo {
|
func newPostgresDB(t *testing.T, baseURL string) DBInfo {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Create a unique schema per test to ensure isolation.
|
// Create a unique schema per test for isolation.
|
||||||
safeName := strings.ReplaceAll(t.Name(), "/", "_")
|
safeName := strings.ReplaceAll(t.Name(), "/", "_")
|
||||||
safeName = strings.ReplaceAll(safeName, " ", "_")
|
safeName = strings.ReplaceAll(safeName, " ", "_")
|
||||||
schema := fmt.Sprintf("test_%s", safeName)
|
schema := fmt.Sprintf("test_%s", safeName)
|
||||||
|
|
||||||
// Connect to the base database to create the schema.
|
|
||||||
baseDB, err := sqlx.Connect("pgx", baseURL)
|
baseDB, err := sqlx.Connect("pgx", baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to postgres: %v", err)
|
t.Fatalf("failed to connect to postgres: %v", err)
|
||||||
}
|
}
|
||||||
|
if _, err := baseDB.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q", schema)); err != nil {
|
||||||
_, err = baseDB.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q", schema))
|
|
||||||
if err != nil {
|
|
||||||
baseDB.Close()
|
baseDB.Close()
|
||||||
t.Fatalf("failed to create schema %s: %v", schema, err)
|
t.Fatalf("failed to create schema %s: %v", schema, err)
|
||||||
}
|
}
|
||||||
baseDB.Close()
|
baseDB.Close()
|
||||||
|
|
||||||
// Connect with a single-connection pool and set search_path to the new schema.
|
// MaxOpenConns(1) ensures every query reuses the connection where
|
||||||
// MaxOpenConns(1) ensures all queries reuse the same connection where
|
|
||||||
// search_path is set (SET is session-level in PostgreSQL).
|
// search_path is set (SET is session-level in PostgreSQL).
|
||||||
pgDB, err := sqlx.Connect("pgx", baseURL)
|
pgDB, err := sqlx.Connect("pgx", baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -96,15 +64,13 @@ func newPostgresDB(t *testing.T, baseURL string) DBInfo {
|
||||||
}
|
}
|
||||||
pgDB.SetMaxOpenConns(1)
|
pgDB.SetMaxOpenConns(1)
|
||||||
|
|
||||||
_, err = pgDB.Exec(fmt.Sprintf(`SET search_path TO "%s"`, schema))
|
if _, err := pgDB.Exec(fmt.Sprintf(`SET search_path TO "%s"`, schema)); err != nil {
|
||||||
if err != nil {
|
|
||||||
pgDB.Close()
|
pgDB.Close()
|
||||||
t.Fatalf("failed to set search_path to %s: %v", schema, err)
|
t.Fatalf("failed to set search_path to %s: %v", schema, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
pgDB.Close()
|
pgDB.Close()
|
||||||
// Drop the schema after the test.
|
|
||||||
cleanDB, err := sqlx.Connect("pgx", baseURL)
|
cleanDB, err := sqlx.Connect("pgx", baseURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %q CASCADE", schema))
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %q CASCADE", schema))
|
||||||
|
|
@ -112,10 +78,9 @@ func newPostgresDB(t *testing.T, baseURL string) DBInfo {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = db.RunMigrations(pgDB.DB, "pgx")
|
if err := db.RunMigrations(pgDB.DB); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to run postgres migrations: %v", err)
|
t.Fatalf("failed to run postgres migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return DBInfo{DB: pgDB, Driver: "pgx"}
|
return DBInfo{DB: pgDB}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue