feat: transfer funds between accounts
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m32s
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m32s
This commit is contained in:
parent
da718427bd
commit
ff237e2fab
14 changed files with 1186 additions and 60 deletions
|
|
@ -14,8 +14,11 @@ type TransactionRepository interface {
|
|||
CreateDepositAtomic(t *model.Transaction, newBalance decimal.Decimal) error
|
||||
UpdateBillAtomic(t *model.Transaction, newBalance decimal.Decimal, categoryID *string) error
|
||||
UpdateDepositAtomic(t *model.Transaction, newBalance decimal.Decimal) error
|
||||
TransferAtomic(withdrawal, deposit *model.Transaction, sourceNewBalance, destNewBalance decimal.Decimal) error
|
||||
GetByID(id string) (*model.Transaction, error)
|
||||
GetCategoryID(transactionID string) (*string, error)
|
||||
GetRelatedID(transactionID string) (*string, error)
|
||||
TransferIDsIn(ids []string) (map[string]bool, error)
|
||||
ListByAccount(accountID string, limit, offset int) ([]*model.Transaction, error)
|
||||
CountByAccount(accountID string) (int, error)
|
||||
}
|
||||
|
|
@ -138,6 +141,55 @@ func (r *transactionRepository) UpdateDepositAtomic(t *model.Transaction, newBal
|
|||
})
|
||||
}
|
||||
|
||||
// TransferAtomic creates the withdrawal + deposit transaction pair, updates both
|
||||
// account balances, and links the two via related_transactions in a single SQL
|
||||
// transaction. Negative balances are allowed — overdraft enforcement is a product
|
||||
// decision left to the service layer.
|
||||
func (r *transactionRepository) TransferAtomic(withdrawal, deposit *model.Transaction, sourceNewBalance, destNewBalance decimal.Decimal) error {
|
||||
return WithTx(r.db, func(tx *sqlx.Tx) error {
|
||||
insertTxn := `
|
||||
INSERT INTO transactions
|
||||
(id, value, type, account_id, title, description, occurred_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9);
|
||||
`
|
||||
if _, err := tx.Exec(insertTxn,
|
||||
withdrawal.ID, withdrawal.Value, withdrawal.Type, withdrawal.AccountID, withdrawal.Title,
|
||||
withdrawal.Description, withdrawal.OccurredAt, withdrawal.CreatedAt, withdrawal.UpdatedAt,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(insertTxn,
|
||||
deposit.ID, deposit.Value, deposit.Type, deposit.AccountID, deposit.Title,
|
||||
deposit.Description, deposit.OccurredAt, deposit.CreatedAt, deposit.UpdatedAt,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateBalance := `UPDATE accounts SET balance = $1, updated_at = $2 WHERE id = $3;`
|
||||
now := time.Now()
|
||||
if _, err := tx.Exec(updateBalance, sourceNewBalance, now, withdrawal.AccountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateBalance, destNewBalance, now, deposit.AccountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// related_transactions has CHECK (transaction_one_id < transaction_two_id);
|
||||
// order the IDs to satisfy it.
|
||||
one, two := withdrawal.ID, deposit.ID
|
||||
if one > two {
|
||||
one, two = two, one
|
||||
}
|
||||
if _, err := tx.Exec(
|
||||
`INSERT INTO related_transactions (transaction_one_id, transaction_two_id) VALUES ($1, $2);`,
|
||||
one, two,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *transactionRepository) GetByID(id string) (*model.Transaction, error) {
|
||||
query := `
|
||||
SELECT id, value, type, account_id, title, description, occurred_at, created_at, updated_at
|
||||
|
|
@ -151,6 +203,55 @@ func (r *transactionRepository) GetByID(id string) (*model.Transaction, error) {
|
|||
return t, nil
|
||||
}
|
||||
|
||||
// GetRelatedID returns the other half of a transfer pair if `transactionID` is
|
||||
// part of one. Returns (nil, nil) when the transaction is standalone.
|
||||
func (r *transactionRepository) GetRelatedID(transactionID string) (*string, error) {
|
||||
var other string
|
||||
err := r.db.Get(&other, `
|
||||
SELECT CASE
|
||||
WHEN transaction_one_id = $1 THEN transaction_two_id
|
||||
ELSE transaction_one_id
|
||||
END
|
||||
FROM related_transactions
|
||||
WHERE transaction_one_id = $1 OR transaction_two_id = $1
|
||||
LIMIT 1;
|
||||
`, transactionID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &other, nil
|
||||
}
|
||||
|
||||
// TransferIDsIn returns the subset of `ids` that appear in related_transactions
|
||||
// (either side). Used by list pages to decide which rows are non-editable so we
|
||||
// don't N+1 a per-row check.
|
||||
func (r *transactionRepository) TransferIDsIn(ids []string) (map[string]bool, error) {
|
||||
if len(ids) == 0 {
|
||||
return map[string]bool{}, nil
|
||||
}
|
||||
query, args, err := sqlx.In(`
|
||||
SELECT transaction_one_id AS id FROM related_transactions WHERE transaction_one_id IN (?)
|
||||
UNION
|
||||
SELECT transaction_two_id AS id FROM related_transactions WHERE transaction_two_id IN (?)
|
||||
`, ids, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = r.db.Rebind(query)
|
||||
var hits []string
|
||||
if err := r.db.Select(&hits, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(map[string]bool, len(hits))
|
||||
for _, id := range hits {
|
||||
out[id] = true
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *transactionRepository) GetCategoryID(transactionID string) (*string, error) {
|
||||
var id string
|
||||
err := r.db.Get(&id, `SELECT category_id FROM transaction_categories WHERE transaction_id = $1 LIMIT 1;`, transactionID)
|
||||
|
|
|
|||
113
internal/repository/transaction_test.go
Normal file
113
internal/repository/transaction_test.go
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||
"git.juancwu.dev/juancwu/budgit/internal/testutil"
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransactionRepository_TransferAtomic_LinksPairAndUpdatesBalances(t *testing.T) {
|
||||
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||
repo := NewTransactionRepository(dbi.DB)
|
||||
|
||||
user := testutil.CreateTestUser(t, dbi.DB, "transfer-repo@example.com", nil)
|
||||
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||
src := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Src")
|
||||
dst := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Dst")
|
||||
|
||||
now := time.Now()
|
||||
withdrawal := &model.Transaction{
|
||||
ID: uuid.NewString(), Value: decimal.NewFromInt(40), Type: model.TransactionTypeWithdrawal,
|
||||
AccountID: src.ID, Title: "Move", OccurredAt: now, CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
deposit := &model.Transaction{
|
||||
ID: uuid.NewString(), Value: decimal.NewFromInt(40), Type: model.TransactionTypeDeposit,
|
||||
AccountID: dst.ID, Title: "Move", OccurredAt: now, CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
|
||||
err := repo.TransferAtomic(withdrawal, deposit, decimal.NewFromInt(-40), decimal.NewFromInt(40))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both transactions exist.
|
||||
w, err := repo.GetByID(withdrawal.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.TransactionTypeWithdrawal, w.Type)
|
||||
d, err := repo.GetByID(deposit.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.TransactionTypeDeposit, d.Type)
|
||||
|
||||
// Balances were applied.
|
||||
accountRepo := NewAccountRepository(dbi.DB)
|
||||
srcAfter, err := accountRepo.ByID(src.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, decimal.NewFromInt(-40).Equal(srcAfter.Balance))
|
||||
dstAfter, err := accountRepo.ByID(dst.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, decimal.NewFromInt(40).Equal(dstAfter.Balance))
|
||||
|
||||
// Linked both ways via related_transactions.
|
||||
other, err := repo.GetRelatedID(withdrawal.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, other)
|
||||
assert.Equal(t, deposit.ID, *other)
|
||||
|
||||
other, err = repo.GetRelatedID(deposit.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, other)
|
||||
assert.Equal(t, withdrawal.ID, *other)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionRepository_TransferIDsIn_ReturnsOnlyTransferHalves(t *testing.T) {
|
||||
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||
repo := NewTransactionRepository(dbi.DB)
|
||||
|
||||
user := testutil.CreateTestUser(t, dbi.DB, "transferids@example.com", nil)
|
||||
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||
src := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Src")
|
||||
dst := testutil.CreateTestAccount(t, dbi.DB, space.ID, "Dst")
|
||||
|
||||
// One transfer pair (source `w`, deposit `d`) plus one standalone txn.
|
||||
now := time.Now()
|
||||
w := &model.Transaction{ID: uuid.NewString(), Value: decimal.NewFromInt(5), Type: model.TransactionTypeWithdrawal, AccountID: src.ID, Title: "T-w", OccurredAt: now, CreatedAt: now, UpdatedAt: now}
|
||||
d := &model.Transaction{ID: uuid.NewString(), Value: decimal.NewFromInt(5), Type: model.TransactionTypeDeposit, AccountID: dst.ID, Title: "T-d", OccurredAt: now, CreatedAt: now, UpdatedAt: now}
|
||||
require.NoError(t, repo.TransferAtomic(w, d, decimal.NewFromInt(-5), decimal.NewFromInt(5)))
|
||||
standalone := testutil.CreateTestTransaction(t, dbi.DB, src.ID, "solo", model.TransactionTypeDeposit, decimal.NewFromInt(1))
|
||||
|
||||
hits, err := repo.TransferIDsIn([]string{w.ID, d.ID, standalone.ID})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hits[w.ID], "withdrawal half should be flagged")
|
||||
assert.True(t, hits[d.ID], "deposit half should be flagged")
|
||||
assert.False(t, hits[standalone.ID], "standalone transaction should not be flagged")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionRepository_TransferIDsIn_EmptyInput(t *testing.T) {
|
||||
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||
repo := NewTransactionRepository(dbi.DB)
|
||||
hits, err := repo.TransferIDsIn(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, hits)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionRepository_GetRelatedID_NoneWhenStandalone(t *testing.T) {
|
||||
testutil.ForEachDB(t, func(t *testing.T, dbi testutil.DBInfo) {
|
||||
repo := NewTransactionRepository(dbi.DB)
|
||||
|
||||
user := testutil.CreateTestUser(t, dbi.DB, "standalone@example.com", nil)
|
||||
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
|
||||
acct := testutil.CreateTestAccount(t, dbi.DB, space.ID, "A")
|
||||
txn := testutil.CreateTestTransaction(t, dbi.DB, acct.ID, "x", model.TransactionTypeDeposit, decimal.NewFromInt(1))
|
||||
|
||||
other, err := repo.GetRelatedID(txn.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, other)
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue