feat: drop sqlite support
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m27s

This commit is contained in:
juancwu 2026-05-04 00:29:45 +00:00
commit da718427bd
27 changed files with 1296 additions and 115 deletions

View file

@ -5,8 +5,8 @@ APP_URL=http://127.0.0.1:7331 # required for base url in email links etc, port i
HOST=127.0.0.1
PORT=9000
DB_DRIVER=sqlite
DB_CONNECTION="./data/local.db?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)"
# PostgreSQL 17 connection string (libpq URL or DSN). Required.
DB_CONNECTION="postgres://budgit:budgit@127.0.0.1:5432/budgit?sslmode=disable"
JWT_SECRET=
# Go duration format

View file

@ -64,11 +64,11 @@ Components live in `internal/ui/components/` — button, input, checkbox, dialog
## Configuration
App reads from `.env` file via `godotenv`. Key vars: `APP_ENV`, `APP_URL`, `DB_DRIVER` (pgx/sqlite), `DB_CONNECTION`, `JWT_SECRET`, `PORT`. See `internal/config/config.go` for all fields.
App reads from `.env` file via `godotenv`. Key vars: `APP_ENV`, `APP_URL`, `DB_CONNECTION` (libpq URL/DSN), `JWT_SECRET`, `PORT`. See `internal/config/config.go` for all fields.
## Database
PostgreSQL (pgx driver) or SQLite. Migrations auto-run on startup from `internal/db/migrations/` (Goose SQL format, embedded via `go:embed`). 8 migration files covering users, tokens, profiles, spaces, shopping lists, tags, expenses, invitations.
PostgreSQL 17 only (pgx driver). Migrations auto-run on startup from `internal/db/migrations/` (Goose SQL format, embedded via `go:embed`). Tests run against an ephemeral Postgres container via `task test`; set `BUDGIT_TEST_POSTGRES_URL` to point at a long-lived instance instead.
# templui Components

View file

@ -26,18 +26,13 @@ tasks:
cmds:
- echo "Starting app..."
- task --parallel tailwind-watch templ
# Testing
# Testing — TestMain in each db-touching package auto-spins an ephemeral
# postgres:17-alpine container; honors BUDGIT_TEST_POSTGRES_URL when set
# (CI uses that path to point at a long-lived service).
test:
desc: Run tests (SQLite only)
desc: Run tests (auto-starts an ephemeral PostgreSQL 17 container if needed)
cmds:
- set -o pipefail && go test ./... -json | tparse -all
test:integration:
desc: Run tests against both SQLite and PostgreSQL
cmds:
- docker run --name budgit-test-pg -d -p 5433:5432 -e POSTGRES_USER=budgit_test -e POSTGRES_PASSWORD=testpass -e POSTGRES_DB=budgit_test postgres:17-alpine
- defer: docker rm -f budgit-test-pg
- cmd: sleep 3
- cmd: BUDGIT_TEST_POSTGRES_URL="postgres://budgit_test:testpass@localhost:5433/budgit_test?sslmode=disable" set -o pipefail && go test ./... -json | tparse -all
# Production build
build:
desc: Build production binary

4
go.mod
View file

@ -15,8 +15,8 @@ require (
github.com/pressly/goose/v3 v3.26.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.11.1
github.com/templui/templui v1.9.5
github.com/wneessen/go-mail v0.7.2
modernc.org/sqlite v1.40.1
)
require (
@ -62,7 +62,6 @@ require (
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/templui/templui v1.9.5 // indirect
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d // indirect
github.com/vertica/vertica-sql-go v1.3.3 // indirect
github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 // indirect
@ -87,6 +86,7 @@ require (
modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.40.1 // indirect
)
tool (

View file

@ -26,12 +26,12 @@ type App struct {
}
func New(cfg *config.Config) (*App, error) {
database, err := db.Init(cfg.DBDriver, cfg.DBConnection)
database, err := db.Init(cfg.DBConnection)
if err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err)
}
err = db.RunMigrations(database.DB, cfg.DBDriver)
err = db.RunMigrations(database.DB)
if err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}

View file

@ -17,7 +17,6 @@ type Config struct {
Host string
Port string
DBDriver string
DBConnection string
JWTSecret string
@ -53,8 +52,7 @@ func Load(version string) *Config {
Host: envString("HOST", "127.0.0.1"),
Port: envString("PORT", "9000"),
DBDriver: envString("DB_DRIVER", "sqlite"),
DBConnection: envString("DB_CONNECTION", "./data/local.db?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)"),
DBConnection: envRequired("DB_CONNECTION"),
JWTSecret: envRequired("JWT_SECRET"),
JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default

View file

@ -3,25 +3,16 @@ package db
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
)
func Init(driver, connection string) (*sqlx.DB, error) {
if driver == "sqlite" {
dir := filepath.Dir(connection)
err := os.MkdirAll(dir, 0755)
if err != nil {
return nil, fmt.Errorf("failed to create data directory: %w", err)
}
}
db, err := sqlx.Connect(driver, connection)
// Init opens a PostgreSQL connection pool. The connection string must be a
// libpq-style URL or DSN supported by the pgx stdlib driver.
func Init(connection string) (*sqlx.DB, error) {
db, err := sqlx.Connect("pgx", connection)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
@ -30,10 +21,9 @@ func Init(driver, connection string) (*sqlx.DB, error) {
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
slog.Info("database connected", "driver", driver)
slog.Info("database connected", "driver", "pgx")
err = db.Ping()
if err != nil {
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}

View file

@ -9,22 +9,8 @@ import (
"github.com/pressly/goose/v3"
)
var dialectMap = map[string]string{
"sqlite": "sqlite3",
"pgx": "postgres",
}
func getDialect(driver string) string {
dialect, ok := dialectMap[driver]
if ok {
return dialect
}
return driver
}
func setupGoose(driver string) error {
err := goose.SetDialect(getDialect(driver))
if err != nil {
func setupGoose() error {
if err := goose.SetDialect("postgres"); err != nil {
return fmt.Errorf("failed to set dialect: %w", err)
}
@ -34,22 +20,18 @@ func setupGoose(driver string) error {
}
goose.SetBaseFS(migrationsDir)
return nil
}
func RunMigrations(db *sql.DB, driver string) error {
err := setupGoose(driver)
if err != nil {
func RunMigrations(db *sql.DB) error {
if err := setupGoose(); err != nil {
return err
}
err = goose.Up(db, ".")
if err != nil {
if err := goose.Up(db, "."); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
slog.Info("migrations completed successfully")
return nil
}

View file

@ -7,7 +7,7 @@ CREATE TABLE space_audit_logs (
action TEXT NOT NULL,
target_user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
target_email TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

View file

@ -5,7 +5,7 @@ CREATE TABLE transaction_audit_logs (
transaction_id TEXT NOT NULL,
actor_id TEXT REFERENCES users(id) ON DELETE SET NULL,
action TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

View file

@ -0,0 +1,22 @@
-- +goose Up
-- +goose StatementBegin
-- The account-scoped activity feeds filter audit rows by metadata->>'account_id'.
-- A partial expression index is the right shape for this access pattern in
-- PostgreSQL 17: it is small (only the rows where the field exists), uses a
-- standard B-tree (cheap equality + ORDER BY created_at), and avoids the bloat
-- of a full GIN over the metadata document.
CREATE INDEX idx_space_audit_logs_account_id
ON space_audit_logs ((metadata->>'account_id'), created_at DESC)
WHERE action LIKE 'account.%';
CREATE INDEX idx_transaction_audit_logs_account_id
ON transaction_audit_logs ((metadata->>'account_id'), created_at DESC)
WHERE metadata ? 'account_id';
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP INDEX IF EXISTS idx_space_audit_logs_account_id;
DROP INDEX IF EXISTS idx_transaction_audit_logs_account_id;
-- +goose StatementEnd

View file

@ -0,0 +1,9 @@
package handler
import (
"testing"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
)
func TestMain(m *testing.M) { testutil.PostgresMain(m) }

View file

@ -0,0 +1,9 @@
package repository
import (
"testing"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
)
func TestMain(m *testing.M) { testutil.PostgresMain(m) }

View file

@ -0,0 +1,124 @@
package repository
import (
"encoding/json"
"testing"
"time"
"git.juancwu.dev/juancwu/budgit/internal/model"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func writeSpaceAuditLog(t *testing.T, repo SpaceAuditLogRepository, spaceID string, action model.SpaceAuditAction, actorID *string, metadata map[string]any, ts time.Time) *model.SpaceAuditLog {
t.Helper()
var meta []byte
if metadata != nil {
var err error
meta, err = json.Marshal(metadata)
require.NoError(t, err)
}
entry := &model.SpaceAuditLog{
ID: uuid.NewString(),
SpaceID: spaceID,
ActorID: actorID,
Action: action,
Metadata: meta,
CreatedAt: ts,
}
require.NoError(t, repo.Create(entry))
return entry
}
func TestSpaceAuditLogRepository_CreateAndList(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewSpaceAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUserWithName(t, dbi.DB, "audit-actor@example.com", strPtr("Actor Name"))
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Audit Space")
base := time.Now().Add(-time.Hour)
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, map[string]any{"old_name": "A", "new_name": "B"}, base)
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionMemberInvited, &actor.ID, nil, base.Add(10*time.Minute))
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionDeleted, &actor.ID, map[string]any{"space_name": "Audit Space"}, base.Add(20*time.Minute))
count, err := repo.CountBySpace(space.ID)
require.NoError(t, err)
assert.Equal(t, 3, count)
logs, err := repo.ListBySpace(space.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 3)
// Newest first.
assert.Equal(t, model.SpaceAuditActionDeleted, logs[0].Action)
assert.Equal(t, model.SpaceAuditActionMemberInvited, logs[1].Action)
assert.Equal(t, model.SpaceAuditActionRenamed, logs[2].Action)
// Actor join populated.
require.NotNil(t, logs[0].ActorName)
assert.Equal(t, "Actor Name", *logs[0].ActorName)
})
}
func TestSpaceAuditLogRepository_ListBySpace_Pagination(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewSpaceAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUser(t, dbi.DB, "page@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Paged Space")
base := time.Now().Add(-time.Hour)
for i := 0; i < 5; i++ {
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, nil, base.Add(time.Duration(i)*time.Minute))
}
page1, err := repo.ListBySpace(space.ID, 2, 0)
require.NoError(t, err)
require.Len(t, page1, 2)
page2, err := repo.ListBySpace(space.ID, 2, 2)
require.NoError(t, err)
require.Len(t, page2, 2)
// No overlap between pages.
assert.NotEqual(t, page1[0].ID, page2[0].ID)
assert.NotEqual(t, page1[1].ID, page2[0].ID)
})
}
func TestSpaceAuditLogRepository_ListAccountEvents_FiltersByMetadata(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewSpaceAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUser(t, dbi.DB, "acct-filter@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Filter Space")
acct1 := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Account 1")
acct2 := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Account 2")
base := time.Now().Add(-time.Hour)
// Account 1 events
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountCreated, &actor.ID, map[string]any{"account_id": acct1.ID, "account_name": "Account 1"}, base)
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountRenamed, &actor.ID, map[string]any{"account_id": acct1.ID, "old_name": "Account 1", "new_name": "Renamed"}, base.Add(time.Minute))
// Account 2 event
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionAccountCreated, &actor.ID, map[string]any{"account_id": acct2.ID, "account_name": "Account 2"}, base.Add(2*time.Minute))
// Non-account event in same space — must NOT appear in account-scoped query
writeSpaceAuditLog(t, repo, space.ID, model.SpaceAuditActionRenamed, &actor.ID, map[string]any{"old_name": "x", "new_name": "y"}, base.Add(3*time.Minute))
acct1Logs, err := repo.ListAccountEvents(acct1.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, acct1Logs, 2)
acct1Count, err := repo.CountAccountEvents(acct1.ID)
require.NoError(t, err)
assert.Equal(t, 2, acct1Count)
acct2Count, err := repo.CountAccountEvents(acct2.ID)
require.NoError(t, err)
assert.Equal(t, 1, acct2Count)
})
}
func strPtr(s string) *string { return &s }

View file

@ -0,0 +1,140 @@
package repository
import (
"encoding/json"
"testing"
"time"
"git.juancwu.dev/juancwu/budgit/internal/model"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func writeTxAuditLog(t *testing.T, repo TransactionAuditLogRepository, transactionID string, action model.TransactionAuditAction, actorID *string, metadata map[string]any, ts time.Time) *model.TransactionAuditLog {
t.Helper()
var meta []byte
if metadata != nil {
var err error
meta, err = json.Marshal(metadata)
require.NoError(t, err)
}
entry := &model.TransactionAuditLog{
ID: uuid.NewString(),
TransactionID: transactionID,
ActorID: actorID,
Action: action,
Metadata: meta,
CreatedAt: ts,
}
require.NoError(t, repo.Create(entry))
return entry
}
func TestTransactionAuditLogRepository_CreateAndListByTransaction(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewTransactionAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUserWithName(t, dbi.DB, "tx-audit@example.com", strPtr("Tx Actor"))
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Tx Audit Space")
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
txn := testutil.CreateTestTransaction(t, dbi.DB, account.ID, "Coffee", model.TransactionTypeWithdrawal, decimal.NewFromInt(5))
base := time.Now().Add(-time.Hour)
writeTxAuditLog(t, repo, txn.ID, model.TransactionAuditActionCreated, &actor.ID,
map[string]any{"account_id": account.ID, "transaction_type": "withdrawal", "title": "Coffee", "amount": "5.00"}, base)
writeTxAuditLog(t, repo, txn.ID, model.TransactionAuditActionEdited, &actor.ID,
map[string]any{"account_id": account.ID, "changes": map[string]any{"title": map[string]any{"old": "Coffee", "new": "Latte"}}}, base.Add(time.Minute))
count, err := repo.CountByTransaction(txn.ID)
require.NoError(t, err)
assert.Equal(t, 2, count)
logs, err := repo.ListByTransaction(txn.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 2)
// Newest first.
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
require.NotNil(t, logs[0].ActorName)
assert.Equal(t, "Tx Actor", *logs[0].ActorName)
})
}
func TestTransactionAuditLogRepository_ListByAccount_LiveAndDeletedFallback(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewTransactionAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUser(t, dbi.DB, "acct-list@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Acct List Space")
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
// Live transaction with audit entry
live := testutil.CreateTestTransaction(t, dbi.DB, account.ID, "Live", model.TransactionTypeDeposit, decimal.NewFromInt(10))
writeTxAuditLog(t, repo, live.ID, model.TransactionAuditActionCreated, &actor.ID,
map[string]any{"account_id": account.ID}, time.Now().Add(-2*time.Minute))
// Audit entry referencing a transaction that no longer exists.
// Resolution must fall back to metadata.account_id.
ghostID := uuid.NewString()
writeTxAuditLog(t, repo, ghostID, model.TransactionAuditActionDeleted, &actor.ID,
map[string]any{"account_id": account.ID, "title": "Ghost"}, time.Now().Add(-time.Minute))
// Audit entry for a different account — must not appear.
other := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Other")
otherTxn := testutil.CreateTestTransaction(t, dbi.DB, other.ID, "Other", model.TransactionTypeDeposit, decimal.NewFromInt(1))
writeTxAuditLog(t, repo, otherTxn.ID, model.TransactionAuditActionCreated, &actor.ID,
map[string]any{"account_id": other.ID}, time.Now())
count, err := repo.CountByAccount(account.ID)
require.NoError(t, err)
assert.Equal(t, 2, count, "should count live + ghost-via-metadata")
logs, err := repo.ListByAccount(account.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 2)
// Confirm both kinds present (one live, one ghost).
ids := []string{logs[0].TransactionID, logs[1].TransactionID}
assert.Contains(t, ids, live.ID)
assert.Contains(t, ids, ghostID)
})
}
func TestTransactionAuditLogRepository_ListBySpace(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
repo := NewTransactionAuditLogRepository(dbi.DB)
actor := testutil.CreateTestUser(t, dbi.DB, "space-list@example.com", nil)
// Two spaces, each with an account and a transaction.
spaceA := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Space A")
acctA := testutil.CreateTestAccount(t, dbi.DB, spaceA.ID, "Acct A")
txnA := testutil.CreateTestTransaction(t, dbi.DB, acctA.ID, "txnA", model.TransactionTypeDeposit, decimal.NewFromInt(1))
writeTxAuditLog(t, repo, txnA.ID, model.TransactionAuditActionCreated, &actor.ID,
map[string]any{"account_id": acctA.ID}, time.Now().Add(-time.Minute))
spaceB := testutil.CreateTestSpace(t, dbi.DB, actor.ID, "Space B")
acctB := testutil.CreateTestAccount(t, dbi.DB, spaceB.ID, "Acct B")
txnB := testutil.CreateTestTransaction(t, dbi.DB, acctB.ID, "txnB", model.TransactionTypeDeposit, decimal.NewFromInt(1))
writeTxAuditLog(t, repo, txnB.ID, model.TransactionAuditActionCreated, &actor.ID,
map[string]any{"account_id": acctB.ID}, time.Now())
// Ghost in space A (deleted txn).
ghostID := uuid.NewString()
writeTxAuditLog(t, repo, ghostID, model.TransactionAuditActionDeleted, &actor.ID,
map[string]any{"account_id": acctA.ID}, time.Now().Add(-30*time.Second))
countA, err := repo.CountBySpace(spaceA.ID)
require.NoError(t, err)
assert.Equal(t, 2, countA)
countB, err := repo.CountBySpace(spaceB.ID)
require.NoError(t, err)
assert.Equal(t, 1, countB)
logsA, err := repo.ListBySpace(spaceA.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logsA, 2)
})
}

View file

@ -0,0 +1,9 @@
package routes
import (
"testing"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
)
func TestMain(m *testing.M) { testutil.PostgresMain(m) }

View file

@ -16,7 +16,6 @@ import (
func newTestApp(dbi testutil.DBInfo) *app.App {
cfg := testutil.TestConfig()
cfg.DBDriver = dbi.Driver
userRepo := repository.NewUserRepository(dbi.DB)
tokenRepo := repository.NewTokenRepository(dbi.DB)

View file

@ -0,0 +1,213 @@
package service
import (
"errors"
"testing"
"time"
"git.juancwu.dev/juancwu/budgit/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubSpaceAuditRepo serves canned responses for the activity-merger tests so we can
// focus on the merge/sort/pagination logic without a real DB.
type stubSpaceAuditRepo struct {
listAccount []*model.SpaceAuditLogWithActor
countAccount int
listSpace []*model.SpaceAuditLogWithActor
countSpace int
err error
}
func (s *stubSpaceAuditRepo) Create(*model.SpaceAuditLog) error { return nil }
func (s *stubSpaceAuditRepo) ListBySpace(_ string, limit, _ int) ([]*model.SpaceAuditLogWithActor, error) {
if s.err != nil {
return nil, s.err
}
return firstN(s.listSpace, limit), nil
}
func (s *stubSpaceAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
func (s *stubSpaceAuditRepo) ListAccountEvents(_ string, limit, _ int) ([]*model.SpaceAuditLogWithActor, error) {
if s.err != nil {
return nil, s.err
}
return firstN(s.listAccount, limit), nil
}
func (s *stubSpaceAuditRepo) CountAccountEvents(string) (int, error) { return s.countAccount, s.err }
type stubTxAuditRepo struct {
listAccount []*model.TransactionAuditLogWithActor
countAccount int
listSpace []*model.TransactionAuditLogWithActor
countSpace int
err error
}
func (s *stubTxAuditRepo) Create(*model.TransactionAuditLog) error { return nil }
func (s *stubTxAuditRepo) ListByTransaction(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
return nil, nil
}
func (s *stubTxAuditRepo) CountByTransaction(string) (int, error) { return 0, nil }
func (s *stubTxAuditRepo) ListByAccount(_ string, limit, _ int) ([]*model.TransactionAuditLogWithActor, error) {
if s.err != nil {
return nil, s.err
}
return firstNTx(s.listAccount, limit), nil
}
func (s *stubTxAuditRepo) CountByAccount(string) (int, error) { return s.countAccount, s.err }
func (s *stubTxAuditRepo) ListBySpace(_ string, limit, _ int) ([]*model.TransactionAuditLogWithActor, error) {
if s.err != nil {
return nil, s.err
}
return firstNTx(s.listSpace, limit), nil
}
func (s *stubTxAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
func firstN(s []*model.SpaceAuditLogWithActor, n int) []*model.SpaceAuditLogWithActor {
if n >= len(s) {
return s
}
return s[:n]
}
func firstNTx(s []*model.TransactionAuditLogWithActor, n int) []*model.TransactionAuditLogWithActor {
if n >= len(s) {
return s
}
return s[:n]
}
func spaceLog(action model.SpaceAuditAction, ts time.Time) *model.SpaceAuditLogWithActor {
return &model.SpaceAuditLogWithActor{
SpaceAuditLog: model.SpaceAuditLog{Action: action, CreatedAt: ts},
}
}
func txLog(action model.TransactionAuditAction, ts time.Time) *model.TransactionAuditLogWithActor {
return &model.TransactionAuditLogWithActor{
TransactionAuditLog: model.TransactionAuditLog{Action: action, CreatedAt: ts},
}
}
func TestAccountActivityService_List_MergesAndSortsByTimestamp(t *testing.T) {
now := time.Now()
spaceRepo := &stubSpaceAuditRepo{
listAccount: []*model.SpaceAuditLogWithActor{
spaceLog(model.SpaceAuditActionAccountRenamed, now.Add(-1*time.Minute)),
spaceLog(model.SpaceAuditActionAccountCreated, now.Add(-10*time.Minute)),
},
countAccount: 2,
}
txRepo := &stubTxAuditRepo{
listAccount: []*model.TransactionAuditLogWithActor{
txLog(model.TransactionAuditActionEdited, now), // newest overall
txLog(model.TransactionAuditActionCreated, now.Add(-5*time.Minute)),
txLog(model.TransactionAuditActionDeleted, now.Add(-15*time.Minute)), // oldest overall
},
countAccount: 3,
}
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
rows, err := svc.List("acct-1", 10, 0)
require.NoError(t, err)
require.Len(t, rows, 5)
// Strictly descending by timestamp.
for i := 1; i < len(rows); i++ {
assert.False(t, rows[i].Timestamp().After(rows[i-1].Timestamp()),
"row %d (%v) is newer than row %d (%v)", i, rows[i].Timestamp(), i-1, rows[i-1].Timestamp())
}
// Top row is the transaction edit at `now`.
require.NotNil(t, rows[0].TxLog)
assert.Equal(t, model.TransactionAuditActionEdited, rows[0].TxLog.Action)
}
func TestAccountActivityService_List_Pagination(t *testing.T) {
now := time.Now()
spaceRepo := &stubSpaceAuditRepo{
listAccount: []*model.SpaceAuditLogWithActor{
spaceLog(model.SpaceAuditActionAccountCreated, now.Add(-30*time.Minute)),
},
}
txRepo := &stubTxAuditRepo{
listAccount: []*model.TransactionAuditLogWithActor{
txLog(model.TransactionAuditActionEdited, now.Add(-10*time.Minute)),
txLog(model.TransactionAuditActionEdited, now.Add(-20*time.Minute)),
txLog(model.TransactionAuditActionEdited, now.Add(-40*time.Minute)),
},
}
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
page1, err := svc.List("a", 2, 0)
require.NoError(t, err)
require.Len(t, page1, 2)
page2, err := svc.List("a", 2, 2)
require.NoError(t, err)
require.Len(t, page2, 2)
// Total of 4 entries; page2[1] is the oldest.
assert.Equal(t, now.Add(-40*time.Minute).Unix(), page2[1].Timestamp().Unix())
}
func TestAccountActivityService_List_OffsetPastEndReturnsEmpty(t *testing.T) {
svc := NewAccountActivityService(
NewSpaceAuditLogService(&stubSpaceAuditRepo{}),
NewTransactionAuditLogService(&stubTxAuditRepo{}),
)
rows, err := svc.List("a", 10, 100)
require.NoError(t, err)
assert.Nil(t, rows)
}
func TestAccountActivityService_Count_SumsBothSources(t *testing.T) {
svc := NewAccountActivityService(
NewSpaceAuditLogService(&stubSpaceAuditRepo{countAccount: 3}),
NewTransactionAuditLogService(&stubTxAuditRepo{countAccount: 7}),
)
count, err := svc.Count("a")
require.NoError(t, err)
assert.Equal(t, 10, count)
}
func TestAccountActivityService_List_PropagatesError(t *testing.T) {
svc := NewAccountActivityService(
NewSpaceAuditLogService(&stubSpaceAuditRepo{err: errors.New("boom")}),
NewTransactionAuditLogService(&stubTxAuditRepo{}),
)
_, err := svc.List("a", 10, 0)
assert.Error(t, err)
}
func TestAccountActivityService_ListSpace_MergesSpaceAndTxFeeds(t *testing.T) {
now := time.Now()
spaceRepo := &stubSpaceAuditRepo{
listSpace: []*model.SpaceAuditLogWithActor{
spaceLog(model.SpaceAuditActionRenamed, now.Add(-3*time.Minute)),
spaceLog(model.SpaceAuditActionMemberInvited, now.Add(-5*time.Minute)),
},
countSpace: 2,
}
txRepo := &stubTxAuditRepo{
listSpace: []*model.TransactionAuditLogWithActor{
txLog(model.TransactionAuditActionCreated, now),
txLog(model.TransactionAuditActionEdited, now.Add(-4*time.Minute)),
},
countSpace: 2,
}
svc := NewAccountActivityService(NewSpaceAuditLogService(spaceRepo), NewTransactionAuditLogService(txRepo))
rows, err := svc.ListSpace("space", 10, 0)
require.NoError(t, err)
require.Len(t, rows, 4)
require.NotNil(t, rows[0].TxLog, "newest is the tx-created row at `now`")
}
func TestAccountActivityService_CountSpace_SumsBothSources(t *testing.T) {
svc := NewAccountActivityService(
NewSpaceAuditLogService(&stubSpaceAuditRepo{countSpace: 4}),
NewTransactionAuditLogService(&stubTxAuditRepo{countSpace: 6}),
)
count, err := svc.CountSpace("s")
require.NoError(t, err)
assert.Equal(t, 10, count)
}

View file

@ -0,0 +1,112 @@
package service
import (
"encoding/json"
"testing"
"git.juancwu.dev/juancwu/budgit/internal/model"
"git.juancwu.dev/juancwu/budgit/internal/repository"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAccountService_CreateAccount_RecordsAudit(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
accountRepo := repository.NewAccountRepository(dbi.DB)
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
auditSvc := NewSpaceAuditLogService(auditRepo)
svc := NewAccountService(accountRepo)
svc.SetAuditLogger(auditSvc)
user := testutil.CreateTestUser(t, dbi.DB, "acct-create-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account, err := svc.CreateAccount(space.ID, "Checking", user.ID)
require.NoError(t, err)
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 1)
assert.Equal(t, model.SpaceAuditActionAccountCreated, logs[0].Action)
var meta map[string]any
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, account.ID, meta["account_id"])
assert.Equal(t, "Checking", meta["account_name"])
})
}
func TestAccountService_RenameAccount_RecordsAuditOnlyWhenChanged(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
accountRepo := repository.NewAccountRepository(dbi.DB)
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
svc := NewAccountService(accountRepo)
svc.SetAuditLogger(NewSpaceAuditLogService(auditRepo))
user := testutil.CreateTestUser(t, dbi.DB, "acct-rename-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Old")
// Rename to a new value records an audit row.
require.NoError(t, svc.RenameAccount(account.ID, "New", user.ID))
// Renaming to the same value does not.
require.NoError(t, svc.RenameAccount(account.ID, "New", user.ID))
count, err := auditRepo.CountAccountEvents(account.ID)
require.NoError(t, err)
assert.Equal(t, 1, count)
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
require.NoError(t, err)
var meta map[string]any
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, "Old", meta["old_name"])
assert.Equal(t, "New", meta["new_name"])
})
}
func TestAccountService_DeleteAccount_RecordsAuditBeforeDelete(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
accountRepo := repository.NewAccountRepository(dbi.DB)
auditRepo := repository.NewSpaceAuditLogRepository(dbi.DB)
svc := NewAccountService(accountRepo)
svc.SetAuditLogger(NewSpaceAuditLogService(auditRepo))
user := testutil.CreateTestUser(t, dbi.DB, "acct-delete-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Target")
require.NoError(t, svc.DeleteAccount(account.ID, user.ID))
// Account is gone.
_, err := accountRepo.ByID(account.ID)
require.Error(t, err)
// Audit row still exists and captured the pre-delete name (no FK on metadata).
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 1)
assert.Equal(t, model.SpaceAuditActionAccountDeleted, logs[0].Action)
var meta map[string]any
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, "Target", meta["account_name"])
})
}
func TestAccountService_NoAuditLoggerSet_DoesNotPanic(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
// SetAuditLogger is intentionally optional so existing tests/callers that
// don't care about audit don't have to wire it.
svc := NewAccountService(repository.NewAccountRepository(dbi.DB))
user := testutil.CreateTestUser(t, dbi.DB, "no-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account, err := svc.CreateAccount(space.ID, "x", user.ID)
require.NoError(t, err)
require.NoError(t, svc.RenameAccount(account.ID, "y", user.ID))
require.NoError(t, svc.DeleteAccount(account.ID, user.ID))
})
}

View file

@ -0,0 +1,9 @@
package service
import (
"testing"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
)
func TestMain(m *testing.M) { testutil.PostgresMain(m) }

View file

@ -0,0 +1,100 @@
package service
import (
"encoding/json"
"errors"
"testing"
"git.juancwu.dev/juancwu/budgit/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeSpaceAuditRepo struct {
created []*model.SpaceAuditLog
failNext error
}
func (f *fakeSpaceAuditRepo) Create(log *model.SpaceAuditLog) error {
if f.failNext != nil {
err := f.failNext
f.failNext = nil
return err
}
f.created = append(f.created, log)
return nil
}
func (f *fakeSpaceAuditRepo) ListBySpace(string, int, int) ([]*model.SpaceAuditLogWithActor, error) {
return nil, nil
}
func (f *fakeSpaceAuditRepo) CountBySpace(string) (int, error) { return 0, nil }
func (f *fakeSpaceAuditRepo) ListAccountEvents(string, int, int) ([]*model.SpaceAuditLogWithActor, error) {
return nil, nil
}
func (f *fakeSpaceAuditRepo) CountAccountEvents(string) (int, error) { return 0, nil }
func TestSpaceAuditLogService_Record_PersistsEntry(t *testing.T) {
repo := &fakeSpaceAuditRepo{}
svc := NewSpaceAuditLogService(repo)
svc.Record(RecordOptions{
SpaceID: "space-1",
ActorID: "actor-1",
Action: model.SpaceAuditActionRenamed,
TargetUserID: "target-1",
TargetEmail: "target@example.com",
Metadata: map[string]any{"old_name": "A", "new_name": "B"},
})
require.Len(t, repo.created, 1)
got := repo.created[0]
assert.Equal(t, "space-1", got.SpaceID)
require.NotNil(t, got.ActorID)
assert.Equal(t, "actor-1", *got.ActorID)
require.NotNil(t, got.TargetUserID)
assert.Equal(t, "target-1", *got.TargetUserID)
require.NotNil(t, got.TargetEmail)
assert.Equal(t, "target@example.com", *got.TargetEmail)
assert.Equal(t, model.SpaceAuditActionRenamed, got.Action)
assert.NotEmpty(t, got.ID)
assert.False(t, got.CreatedAt.IsZero())
var meta map[string]any
require.NoError(t, json.Unmarshal(got.Metadata, &meta))
assert.Equal(t, "A", meta["old_name"])
assert.Equal(t, "B", meta["new_name"])
}
func TestSpaceAuditLogService_Record_OmitsBlankOptionalFields(t *testing.T) {
repo := &fakeSpaceAuditRepo{}
svc := NewSpaceAuditLogService(repo)
svc.Record(RecordOptions{
SpaceID: "space-1",
Action: model.SpaceAuditActionDeleted,
})
require.Len(t, repo.created, 1)
got := repo.created[0]
assert.Nil(t, got.ActorID)
assert.Nil(t, got.TargetUserID)
assert.Nil(t, got.TargetEmail)
assert.Empty(t, got.Metadata)
}
func TestSpaceAuditLogService_Record_SwallowsRepoError(t *testing.T) {
// Audit failures must not bubble up to break the user's action.
repo := &fakeSpaceAuditRepo{failNext: errors.New("boom")}
svc := NewSpaceAuditLogService(repo)
assert.NotPanics(t, func() {
svc.Record(RecordOptions{SpaceID: "s", Action: model.SpaceAuditActionRenamed})
})
assert.Empty(t, repo.created)
}
func TestSpaceAuditLogService_Record_NilReceiverIsNoOp(t *testing.T) {
var svc *SpaceAuditLogService
assert.NotPanics(t, func() {
svc.Record(RecordOptions{SpaceID: "s", Action: model.SpaceAuditActionRenamed})
})
}

View file

@ -0,0 +1,76 @@
package service
import (
"encoding/json"
"errors"
"testing"
"git.juancwu.dev/juancwu/budgit/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeTxAuditRepo struct {
created []*model.TransactionAuditLog
failNext error
}
func (f *fakeTxAuditRepo) Create(log *model.TransactionAuditLog) error {
if f.failNext != nil {
err := f.failNext
f.failNext = nil
return err
}
f.created = append(f.created, log)
return nil
}
func (f *fakeTxAuditRepo) ListByTransaction(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
return nil, nil
}
func (f *fakeTxAuditRepo) CountByTransaction(string) (int, error) { return 0, nil }
func (f *fakeTxAuditRepo) ListByAccount(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
return nil, nil
}
func (f *fakeTxAuditRepo) CountByAccount(string) (int, error) { return 0, nil }
func (f *fakeTxAuditRepo) ListBySpace(string, int, int) ([]*model.TransactionAuditLogWithActor, error) {
return nil, nil
}
func (f *fakeTxAuditRepo) CountBySpace(string) (int, error) { return 0, nil }
func TestTransactionAuditLogService_Record_PersistsEntry(t *testing.T) {
repo := &fakeTxAuditRepo{}
svc := NewTransactionAuditLogService(repo)
svc.Record(TransactionRecordOptions{
TransactionID: "txn-1",
ActorID: "actor-1",
Action: model.TransactionAuditActionEdited,
Metadata: map[string]any{"changes": map[string]any{"title": "x"}},
})
require.Len(t, repo.created, 1)
got := repo.created[0]
assert.Equal(t, "txn-1", got.TransactionID)
require.NotNil(t, got.ActorID)
assert.Equal(t, "actor-1", *got.ActorID)
assert.Equal(t, model.TransactionAuditActionEdited, got.Action)
var meta map[string]any
require.NoError(t, json.Unmarshal(got.Metadata, &meta))
assert.Contains(t, meta, "changes")
}
func TestTransactionAuditLogService_Record_SwallowsRepoError(t *testing.T) {
repo := &fakeTxAuditRepo{failNext: errors.New("boom")}
svc := NewTransactionAuditLogService(repo)
assert.NotPanics(t, func() {
svc.Record(TransactionRecordOptions{TransactionID: "x", Action: model.TransactionAuditActionEdited})
})
}
func TestTransactionAuditLogService_Record_NilReceiverIsNoOp(t *testing.T) {
var svc *TransactionAuditLogService
assert.NotPanics(t, func() {
svc.Record(TransactionRecordOptions{TransactionID: "x", Action: model.TransactionAuditActionEdited})
})
}

View file

@ -0,0 +1,265 @@
package service
import (
"encoding/json"
"testing"
"time"
"git.juancwu.dev/juancwu/budgit/internal/model"
"git.juancwu.dev/juancwu/budgit/internal/repository"
"git.juancwu.dev/juancwu/budgit/internal/testutil"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// txnFixture builds a fully wired TransactionService against a real DB along with
// the helper repos the tests need to inspect post-state.
type txnFixture struct {
svc *TransactionService
txAudit repository.TransactionAuditLogRepository
accounts repository.AccountRepository
user *model.User
account *model.Account
}
func newTxnFixture(t *testing.T, dbi testutil.DBInfo) *txnFixture {
t.Helper()
txnRepo := repository.NewTransactionRepository(dbi.DB)
categoryRepo := repository.NewCategoryRepository(dbi.DB)
accountRepo := repository.NewAccountRepository(dbi.DB)
auditRepo := repository.NewTransactionAuditLogRepository(dbi.DB)
accountSvc := NewAccountService(accountRepo)
auditSvc := NewTransactionAuditLogService(auditRepo)
svc := NewTransactionService(txnRepo, categoryRepo, accountSvc)
svc.SetAuditLogger(auditSvc)
user := testutil.CreateTestUser(t, dbi.DB, t.Name()+"@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Acct")
return &txnFixture{
svc: svc,
txAudit: auditRepo,
accounts: accountRepo,
user: user,
account: account,
}
}
func TestTransactionService_Deposit_RecordsAuditAndUpdatesBalance(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
txn, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID,
Title: "Paycheck",
Amount: decimal.NewFromInt(100),
OccurredAt: time.Now(),
ActorID: f.user.ID,
})
require.NoError(t, err)
assert.True(t, decimal.NewFromInt(100).Equal(txn.Value))
updated, err := f.accounts.ByID(f.account.ID)
require.NoError(t, err)
assert.True(t, decimal.NewFromInt(100).Equal(updated.Balance))
logs, err := f.txAudit.ListByTransaction(txn.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 1)
assert.Equal(t, model.TransactionAuditActionCreated, logs[0].Action)
var meta map[string]any
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, "deposit", meta["transaction_type"])
assert.Equal(t, f.account.ID, meta["account_id"])
assert.Equal(t, "Paycheck", meta["title"])
assert.Equal(t, "100.00", meta["amount"])
})
}
func TestTransactionService_PayBill_RecordsAuditAndDebitsBalance(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
// Seed some balance via deposit.
_, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(50), OccurredAt: time.Now(), ActorID: f.user.ID,
})
require.NoError(t, err)
txn, err := f.svc.PayBill(PayBillInput{
AccountID: f.account.ID,
Title: "Rent",
Amount: decimal.NewFromInt(20),
OccurredAt: time.Now(),
ActorID: f.user.ID,
})
require.NoError(t, err)
updated, err := f.accounts.ByID(f.account.ID)
require.NoError(t, err)
assert.True(t, decimal.NewFromInt(30).Equal(updated.Balance))
logs, err := f.txAudit.ListByTransaction(txn.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 1)
var meta map[string]any
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, "withdrawal", meta["transaction_type"])
})
}
func TestTransactionService_UpdateDeposit_RebalancesAndDiffs(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
original, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID, Title: "Old", Amount: decimal.NewFromInt(40),
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
ActorID: f.user.ID,
})
require.NoError(t, err)
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
TransactionID: original.ID,
Title: "New",
Amount: decimal.NewFromInt(60), // +20 net
OccurredAt: time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC),
ActorID: f.user.ID,
})
require.NoError(t, err)
// Balance reflects the swap (40 → 60 means +20 from 40 baseline).
updated, err := f.accounts.ByID(f.account.ID)
require.NoError(t, err)
assert.True(t, decimal.NewFromInt(60).Equal(updated.Balance))
// 2 audit rows: created + edited (newest first).
logs, err := f.txAudit.ListByTransaction(original.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 2)
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
var meta struct {
AccountID string `json:"account_id"`
Changes map[string]map[string]any `json:"changes"`
}
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, f.account.ID, meta.AccountID)
assert.Contains(t, meta.Changes, "title")
assert.Equal(t, "Old", meta.Changes["title"]["old"])
assert.Equal(t, "New", meta.Changes["title"]["new"])
assert.Contains(t, meta.Changes, "amount")
assert.Equal(t, "40.00", meta.Changes["amount"]["old"])
assert.Equal(t, "60.00", meta.Changes["amount"]["new"])
assert.Contains(t, meta.Changes, "occurred_at")
})
}
func TestTransactionService_UpdateDeposit_NoChanges_NoAudit(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
original, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID, Title: "Same", Amount: decimal.NewFromInt(10),
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
ActorID: f.user.ID,
})
require.NoError(t, err)
// Update with identical values.
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
TransactionID: original.ID,
Title: "Same",
Amount: decimal.NewFromInt(10),
OccurredAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
ActorID: f.user.ID,
})
require.NoError(t, err)
// Only the original `created` audit row exists; no `edited` row.
count, err := f.txAudit.CountByTransaction(original.ID)
require.NoError(t, err)
assert.Equal(t, 1, count)
})
}
func TestTransactionService_UpdateBill_RebalancesAndDiffs(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
// Seed funds and then a bill.
_, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(100), OccurredAt: time.Now(), ActorID: f.user.ID,
})
require.NoError(t, err)
bill, err := f.svc.PayBill(PayBillInput{
AccountID: f.account.ID, Title: "Cable", Amount: decimal.NewFromInt(30), OccurredAt: time.Now(), ActorID: f.user.ID,
})
require.NoError(t, err)
_, err = f.svc.UpdateBill(UpdateBillInput{
TransactionID: bill.ID,
Title: "Internet",
Amount: decimal.NewFromInt(40), // -10 vs original
OccurredAt: bill.OccurredAt,
ActorID: f.user.ID,
})
require.NoError(t, err)
// 100 - 40 = 60.
updated, err := f.accounts.ByID(f.account.ID)
require.NoError(t, err)
assert.True(t, decimal.NewFromInt(60).Equal(updated.Balance))
logs, err := f.txAudit.ListByTransaction(bill.ID, 10, 0)
require.NoError(t, err)
require.Len(t, logs, 2)
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
})
}
func TestTransactionService_UpdateDeposit_RejectsBillTransaction(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
_, err := f.svc.Deposit(DepositInput{
AccountID: f.account.ID, Title: "seed", Amount: decimal.NewFromInt(50), OccurredAt: time.Now(), ActorID: f.user.ID,
})
require.NoError(t, err)
bill, err := f.svc.PayBill(PayBillInput{
AccountID: f.account.ID, Title: "Bill", Amount: decimal.NewFromInt(10), OccurredAt: time.Now(), ActorID: f.user.ID,
})
require.NoError(t, err)
_, err = f.svc.UpdateDeposit(UpdateDepositInput{
TransactionID: bill.ID,
Title: "x",
Amount: decimal.NewFromInt(1),
OccurredAt: time.Now(),
ActorID: f.user.ID,
})
require.Error(t, err)
})
}
func TestTransactionService_Validations(t *testing.T) {
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
f := newTxnFixture(t, dbi)
_, err := f.svc.Deposit(DepositInput{AccountID: f.account.ID, Amount: decimal.NewFromInt(1), OccurredAt: time.Now()})
assert.Error(t, err, "blank title")
_, err = f.svc.Deposit(DepositInput{AccountID: f.account.ID, Title: "x", Amount: decimal.NewFromInt(0), OccurredAt: time.Now()})
assert.Error(t, err, "zero amount")
_, err = f.svc.Deposit(DepositInput{AccountID: f.account.ID, Title: "x", Amount: decimal.NewFromInt(1)})
assert.Error(t, err, "missing date")
_, err = f.svc.PayBill(PayBillInput{Title: "x", Amount: decimal.NewFromInt(1), OccurredAt: time.Now()})
assert.Error(t, err, "missing account id")
})
}

View file

@ -23,8 +23,7 @@ func TestConfig() *config.Config {
AppURL: "http://localhost:9999",
Host: "127.0.0.1",
Port: "9999",
DBDriver: "sqlite",
DBConnection: ":memory:",
DBConnection: "",
JWTSecret: "test-secret-key-for-testing-only",
JWTExpiry: 24 * time.Hour,
TokenMagicLinkExpiry: 10 * time.Minute,

View file

@ -0,0 +1,118 @@
package testutil
import (
"database/sql"
"fmt"
"net"
"os"
"os/exec"
"testing"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
)
// PostgresMain is the TestMain entry point used by every package whose tests touch
// the database. It guarantees a running PostgreSQL 17 instance for the duration of
// the test binary:
//
// - If BUDGIT_TEST_POSTGRES_URL is already set, it is used as-is. CI and `task test`
// hit this path.
// - Otherwise an ephemeral `postgres:17-alpine` container is started on a free
// local port, BUDGIT_TEST_POSTGRES_URL is exported to it for the test process,
// and the container is removed when the test binary exits — even on panic, via
// a deferred cleanup around m.Run().
//
// Usage in each test package:
//
// func TestMain(m *testing.M) { testutil.PostgresMain(m) }
func PostgresMain(m *testing.M) {
if os.Getenv("BUDGIT_TEST_POSTGRES_URL") != "" {
os.Exit(m.Run())
}
if _, err := exec.LookPath("docker"); err != nil {
fmt.Fprintln(os.Stderr, "testutil.PostgresMain: BUDGIT_TEST_POSTGRES_URL is unset and `docker` is not on PATH; cannot run db tests")
os.Exit(1)
}
port, err := freePort()
if err != nil {
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: failed to find free port: %v\n", err)
os.Exit(1)
}
containerName := fmt.Sprintf("budgit-test-pg-%d-%d", os.Getpid(), time.Now().UnixNano())
startCmd := exec.Command("docker", "run", "--rm", "-d",
"--name", containerName,
"-p", fmt.Sprintf("%d:5432", port),
"-e", "POSTGRES_USER=budgit_test",
"-e", "POSTGRES_PASSWORD=testpass",
"-e", "POSTGRES_DB=budgit_test",
// tmpfs for the data dir keeps tests fast — we don't care about durability.
"--tmpfs", "/var/lib/postgresql/data:rw",
"postgres:17-alpine",
)
if out, err := startCmd.CombinedOutput(); err != nil {
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: docker run failed: %v\n%s\n", err, out)
os.Exit(1)
}
stop := func() {
// `docker rm -f` because --rm only fires on a clean exit; force-stop the
// container regardless of state so leftover containers don't accumulate.
_ = exec.Command("docker", "rm", "-f", containerName).Run()
}
url := fmt.Sprintf("postgres://budgit_test:testpass@127.0.0.1:%d/budgit_test?sslmode=disable", port)
if err := waitForPostgres(url, 60*time.Second); err != nil {
stop()
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: postgres did not become ready: %v\n", err)
os.Exit(1)
}
if err := os.Setenv("BUDGIT_TEST_POSTGRES_URL", url); err != nil {
stop()
fmt.Fprintf(os.Stderr, "testutil.PostgresMain: setenv failed: %v\n", err)
os.Exit(1)
}
// Run tests, then ALWAYS stop the container — including on panic.
code := func() int {
defer stop()
return m.Run()
}()
os.Exit(code)
}
func freePort() (int, error) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
// waitForPostgres polls until a real client connection succeeds. pg_isready isn't
// sufficient under load — under parallel `go test ./...` we've seen it report ready
// while client connections still fail with "unexpected EOF" because the server is
// still finishing startup.
func waitForPostgres(url string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := sql.Open("pgx", url)
if err == nil {
if err = db.Ping(); err == nil {
_ = db.Close()
return nil
}
_ = db.Close()
}
lastErr = err
time.Sleep(200 * time.Millisecond)
}
return fmt.Errorf("timed out after %s: %w", timeout, lastErr)
}

View file

@ -7,6 +7,7 @@ import (
"git.juancwu.dev/juancwu/budgit/internal/model"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
)
// CreateTestUser inserts a user directly into the database.
@ -79,6 +80,52 @@ func CreateTestSpace(t *testing.T, db *sqlx.DB, ownerID, name string) *model.Spa
return space
}
// CreateTestAccount inserts an account directly into the database.
func CreateTestAccount(t *testing.T, db *sqlx.DB, spaceID, name string) *model.Account {
t.Helper()
now := time.Now()
account := &model.Account{
ID: uuid.NewString(),
Name: name,
SpaceID: spaceID,
Balance: decimal.Zero,
CreatedAt: now,
UpdatedAt: now,
}
_, err := db.Exec(
`INSERT INTO accounts (id, name, space_id, balance, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`,
account.ID, account.Name, account.SpaceID, account.Balance, account.CreatedAt, account.UpdatedAt,
)
if err != nil {
t.Fatalf("CreateTestAccount: %v", err)
}
return account
}
// CreateTestTransaction inserts a transaction directly into the database.
func CreateTestTransaction(t *testing.T, db *sqlx.DB, accountID, title string, txnType model.TransactionType, amount decimal.Decimal) *model.Transaction {
t.Helper()
now := time.Now()
txn := &model.Transaction{
ID: uuid.NewString(),
Value: amount,
Type: txnType,
AccountID: accountID,
Title: title,
OccurredAt: now,
CreatedAt: now,
UpdatedAt: now,
}
_, err := db.Exec(
`INSERT INTO transactions (id, value, type, account_id, title, description, occurred_at, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
txn.ID, txn.Value, txn.Type, txn.AccountID, txn.Title, txn.Description, txn.OccurredAt, txn.CreatedAt, txn.UpdatedAt,
)
if err != nil {
t.Fatalf("CreateTestTransaction: %v", err)
}
return txn
}
// CreateTestToken inserts a token directly into the database.
func CreateTestToken(t *testing.T, db *sqlx.DB, userID, tokenType, tokenString string, expiresAt time.Time) *model.Token {
t.Helper()

View file

@ -10,26 +10,24 @@ import (
"github.com/jmoiron/sqlx"
)
// DBInfo holds a test database connection and its driver name.
// DBInfo holds a test database connection.
type DBInfo struct {
DB *sqlx.DB
Driver string
DB *sqlx.DB
}
// ForEachDB runs the given test function against both SQLite and PostgreSQL.
// PostgreSQL tests are skipped when BUDGIT_TEST_POSTGRES_URL is unset.
// ForEachDB runs the test function against PostgreSQL. Skips when
// BUDGIT_TEST_POSTGRES_URL is not set so quick local runs don't fail
// without a database. CI must always set it.
//
// Each test gets its own schema for isolation; the schema is dropped on
// cleanup. The function name is preserved for backwards compatibility,
// although there is now only one engine.
func ForEachDB(t *testing.T, fn func(t *testing.T, dbi DBInfo)) {
t.Helper()
t.Run("sqlite", func(t *testing.T) {
t.Parallel()
dbi := newSQLiteDB(t)
fn(t, dbi)
})
pgURL := os.Getenv("BUDGIT_TEST_POSTGRES_URL")
if pgURL == "" {
t.Log("skipping postgres tests: BUDGIT_TEST_POSTGRES_URL not set")
t.Skip("skipping db tests: BUDGIT_TEST_POSTGRES_URL not set")
return
}
@ -40,55 +38,25 @@ func ForEachDB(t *testing.T, fn func(t *testing.T, dbi DBInfo)) {
})
}
func newSQLiteDB(t *testing.T) DBInfo {
t.Helper()
// Use a unique in-memory database per test via a unique DSN.
// Each file::memory:?cache=shared&name=X uses a separate in-memory DB.
safeName := strings.ReplaceAll(t.Name(), "/", "_")
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_pragma=foreign_keys(1)", safeName)
sqliteDB, err := sqlx.Connect("sqlite", dsn)
if err != nil {
t.Fatalf("failed to connect to sqlite: %v", err)
}
// SQLite in-memory DBs are destroyed when the last connection closes.
// Keep at least one open so it survives the test.
sqliteDB.SetMaxOpenConns(1)
t.Cleanup(func() { sqliteDB.Close() })
err = db.RunMigrations(sqliteDB.DB, "sqlite")
if err != nil {
t.Fatalf("failed to run sqlite migrations: %v", err)
}
return DBInfo{DB: sqliteDB, Driver: "sqlite"}
}
func newPostgresDB(t *testing.T, baseURL string) DBInfo {
t.Helper()
// Create a unique schema per test to ensure isolation.
// Create a unique schema per test for isolation.
safeName := strings.ReplaceAll(t.Name(), "/", "_")
safeName = strings.ReplaceAll(safeName, " ", "_")
schema := fmt.Sprintf("test_%s", safeName)
// Connect to the base database to create the schema.
baseDB, err := sqlx.Connect("pgx", baseURL)
if err != nil {
t.Fatalf("failed to connect to postgres: %v", err)
}
_, err = baseDB.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q", schema))
if err != nil {
if _, err := baseDB.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q", schema)); err != nil {
baseDB.Close()
t.Fatalf("failed to create schema %s: %v", schema, err)
}
baseDB.Close()
// Connect with a single-connection pool and set search_path to the new schema.
// MaxOpenConns(1) ensures all queries reuse the same connection where
// MaxOpenConns(1) ensures every query reuses the connection where
// search_path is set (SET is session-level in PostgreSQL).
pgDB, err := sqlx.Connect("pgx", baseURL)
if err != nil {
@ -96,15 +64,13 @@ func newPostgresDB(t *testing.T, baseURL string) DBInfo {
}
pgDB.SetMaxOpenConns(1)
_, err = pgDB.Exec(fmt.Sprintf(`SET search_path TO "%s"`, schema))
if err != nil {
if _, err := pgDB.Exec(fmt.Sprintf(`SET search_path TO "%s"`, schema)); err != nil {
pgDB.Close()
t.Fatalf("failed to set search_path to %s: %v", schema, err)
}
t.Cleanup(func() {
pgDB.Close()
// Drop the schema after the test.
cleanDB, err := sqlx.Connect("pgx", baseURL)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %q CASCADE", schema))
@ -112,10 +78,9 @@ func newPostgresDB(t *testing.T, baseURL string) DBInfo {
}
})
err = db.RunMigrations(pgDB.DB, "pgx")
if err != nil {
if err := db.RunMigrations(pgDB.DB); err != nil {
t.Fatalf("failed to run postgres migrations: %v", err)
}
return DBInfo{DB: pgDB, Driver: "pgx"}
return DBInfo{DB: pgDB}
}