feat: investment accounts
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m50s

This commit is contained in:
juancwu 2026-05-22 14:49:57 +00:00
commit 7c24a8302d
25 changed files with 2205 additions and 56 deletions

View file

@ -34,47 +34,126 @@ func (s *AccountService) SetAuditLogger(audit *SpaceAuditLogService) {
s.auditSvc = audit
}
func (s *AccountService) CreateAccount(spaceID, name, currencyCode, actorID string) (*model.Account, error) {
if spaceID == "" {
// CreateAccountInput captures all the fields the caller can set when creating
// an account. isInvestment + investmentSubtype are optional; if isInvestment is
// false the subtype is forced to nil.
type CreateAccountInput struct {
SpaceID string
Name string
CurrencyCode string
IsInvestment bool
InvestmentSubtype string // canonical lowercase string; ignored if IsInvestment is false
ActorID string
}
func (s *AccountService) CreateAccount(input CreateAccountInput) (*model.Account, error) {
if input.SpaceID == "" {
return nil, fmt.Errorf("space id is required")
}
if name == "" {
if input.Name == "" {
return nil, fmt.Errorf("account name cannot be empty")
}
code := currency.Normalize(currencyCode)
code := currency.Normalize(input.CurrencyCode)
if code == "" {
code = currency.Default
}
if !currency.IsValid(code) {
return nil, fmt.Errorf("unsupported currency code: %s", currencyCode)
return nil, fmt.Errorf("unsupported currency code: %s", input.CurrencyCode)
}
var subtypePtr *string
if input.IsInvestment {
sub := input.InvestmentSubtype
if !model.IsValidInvestmentSubtype(sub) {
return nil, fmt.Errorf("invalid investment subtype: %s", sub)
}
subtypePtr = &sub
}
now := time.Now()
account := &model.Account{
ID: uuid.NewString(),
Name: name,
SpaceID: spaceID,
Currency: code,
CreatedAt: now,
UpdatedAt: now,
ID: uuid.NewString(),
Name: input.Name,
SpaceID: input.SpaceID,
Currency: code,
IsInvestment: input.IsInvestment,
InvestmentSubtype: subtypePtr,
CreatedAt: now,
UpdatedAt: now,
}
if err := s.accountRepo.Create(account); err != nil {
return nil, fmt.Errorf("failed to create account: %w", err)
}
meta := map[string]any{
"account_id": account.ID,
"account_name": account.Name,
"currency": account.Currency,
}
if account.IsInvestment {
meta["is_investment"] = true
if subtypePtr != nil {
meta["investment_subtype"] = *subtypePtr
}
}
s.auditSvc.Record(RecordOptions{
SpaceID: spaceID,
ActorID: actorID,
Action: model.SpaceAuditActionAccountCreated,
Metadata: map[string]any{
"account_id": account.ID,
"account_name": account.Name,
"currency": account.Currency,
},
SpaceID: input.SpaceID,
ActorID: input.ActorID,
Action: model.SpaceAuditActionAccountCreated,
Metadata: meta,
})
return account, nil
}
// SetInvestmentFlag toggles the investment flag on an existing account. When
// turning the flag off, the subtype is cleared.
func (s *AccountService) SetInvestmentFlag(accountID string, isInvestment bool, subtype string, actorID string) error {
if accountID == "" {
return fmt.Errorf("account id is required")
}
account, err := s.accountRepo.ByID(accountID)
if err != nil {
return fmt.Errorf("failed to load account: %w", err)
}
var subtypePtr *string
if isInvestment {
if !model.IsValidInvestmentSubtype(subtype) {
return fmt.Errorf("invalid investment subtype: %s", subtype)
}
s := subtype
subtypePtr = &s
}
if err := s.accountRepo.SetInvestment(accountID, isInvestment, subtypePtr); err != nil {
return fmt.Errorf("failed to update investment flag: %w", err)
}
s.auditSvc.Record(RecordOptions{
SpaceID: account.SpaceID,
ActorID: actorID,
Action: model.SpaceAuditActionAccountInvestmentFlag,
Metadata: map[string]any{
"account_id": accountID,
"account_name": account.Name,
"is_investment": isInvestment,
"investment_subtype": subtypePtr,
},
})
return nil
}
// InvestmentAccountsForUser lists every investment-flagged account in spaces
// the user is a member of (including spaces they own).
func (s *AccountService) InvestmentAccountsForUser(userID string) ([]*model.Account, error) {
if userID == "" {
return nil, fmt.Errorf("user id is required")
}
accounts, err := s.accountRepo.InvestmentAccountsByUserID(userID)
if err != nil {
return nil, fmt.Errorf("failed to list investment accounts: %w", err)
}
return accounts, nil
}
func (s *AccountService) GetAccount(id string) (*model.Account, error) {
account, err := s.accountRepo.ByID(id)
if err != nil {

View file

@ -22,7 +22,7 @@ func TestAccountService_CreateAccount_RecordsAudit(t *testing.T) {
user := testutil.CreateTestUser(t, dbi.DB, "acct-create-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account, err := svc.CreateAccount(space.ID, "Checking", "CAD", user.ID)
account, err := svc.CreateAccount(CreateAccountInput{SpaceID: space.ID, Name: "Checking", CurrencyCode: "CAD", ActorID: user.ID})
require.NoError(t, err)
logs, err := auditRepo.ListAccountEvents(account.ID, 10, 0)
@ -104,7 +104,7 @@ func TestAccountService_NoAuditLoggerSet_DoesNotPanic(t *testing.T) {
user := testutil.CreateTestUser(t, dbi.DB, "no-audit@example.com", nil)
space := testutil.CreateTestSpace(t, dbi.DB, user.ID, "S")
account, err := svc.CreateAccount(space.ID, "x", "", user.ID)
account, err := svc.CreateAccount(CreateAccountInput{SpaceID: space.ID, Name: "x", ActorID: user.ID})
require.NoError(t, err)
require.NoError(t, svc.RenameAccount(account.ID, "y", user.ID))
require.NoError(t, svc.DeleteAccount(account.ID, user.ID))

View file

@ -371,7 +371,11 @@ func (s *AuthService) CompleteOnboarding(userID, name string) error {
return fmt.Errorf("failed to create onboarding space: %w", err)
}
if _, err := s.accountService.CreateAccount(space.ID, DefaultAccountName, "", userID); err != nil {
if _, err := s.accountService.CreateAccount(CreateAccountInput{
SpaceID: space.ID,
Name: DefaultAccountName,
ActorID: userID,
}); err != nil {
if delErr := s.spaceService.DeleteSpace(space.ID, userID); delErr != nil {
slog.Error("failed to roll back space after account creation error",
"space_id", space.ID, "error", delErr)

View file

@ -0,0 +1,355 @@
package service
import (
"fmt"
"time"
"git.juancwu.dev/juancwu/budgit/internal/model"
"git.juancwu.dev/juancwu/budgit/internal/repository"
"github.com/google/uuid"
"github.com/shopspring/decimal"
)
// InvestmentService handles contribution rooms, holdings, trades, and the
// summary view for investment-flagged accounts. Cash movement (contributions
// and withdrawals) still goes through TransactionService / TransferService;
// this service reads from those tables but never mutates them directly.
type InvestmentService struct {
accountRepo repository.AccountRepository
roomRepo repository.InvestmentContributionRoomRepository
holdingRepo repository.InvestmentHoldingRepository
tradeRepo repository.InvestmentTradeRepository
txRepo repository.TransactionRepository
}
func NewInvestmentService(
accountRepo repository.AccountRepository,
roomRepo repository.InvestmentContributionRoomRepository,
holdingRepo repository.InvestmentHoldingRepository,
tradeRepo repository.InvestmentTradeRepository,
txRepo repository.TransactionRepository,
) *InvestmentService {
return &InvestmentService{
accountRepo: accountRepo,
roomRepo: roomRepo,
holdingRepo: holdingRepo,
tradeRepo: tradeRepo,
txRepo: txRepo,
}
}
// ---------- Contribution room ----------
func (s *InvestmentService) SetContributionRoom(accountID string, year int, room decimal.Decimal) error {
if accountID == "" {
return fmt.Errorf("account id is required")
}
if year < 1900 || year > 9999 {
return fmt.Errorf("year out of range")
}
if room.IsNegative() {
return fmt.Errorf("contribution room cannot be negative")
}
account, err := s.accountRepo.ByID(accountID)
if err != nil {
return fmt.Errorf("failed to load account: %w", err)
}
if !account.IsInvestment {
return fmt.Errorf("account is not an investment account")
}
now := time.Now()
return s.roomRepo.Upsert(&model.InvestmentContributionRoom{
AccountID: accountID,
Year: year,
RoomAmount: room,
CreatedAt: now,
UpdatedAt: now,
})
}
func (s *InvestmentService) GetContributionRoom(accountID string, year int) (*model.InvestmentContributionRoom, error) {
room, err := s.roomRepo.ByAccountAndYear(accountID, year)
if err != nil {
if err == repository.ErrContributionRoomNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to load contribution room: %w", err)
}
return room, nil
}
func (s *InvestmentService) ListContributionRooms(accountID string) ([]*model.InvestmentContributionRoom, error) {
rooms, err := s.roomRepo.ByAccountID(accountID)
if err != nil {
return nil, fmt.Errorf("failed to list contribution rooms: %w", err)
}
return rooms, nil
}
// ---------- Summary ----------
// SummarizeAccount produces the rollup view for an investment account in the
// given calendar year: contribution room, YTD cash flow, lifetime net
// contributions, and total cost basis across all holdings.
func (s *InvestmentService) SummarizeAccount(accountID string, year int) (*model.InvestmentAccountSummary, error) {
account, err := s.accountRepo.ByID(accountID)
if err != nil {
return nil, fmt.Errorf("failed to load account: %w", err)
}
if !account.IsInvestment {
return nil, fmt.Errorf("account is not an investment account")
}
ytdContrib, err := s.txRepo.SumByAccountYearType(accountID, year, model.TransactionTypeDeposit)
if err != nil {
return nil, fmt.Errorf("failed to sum ytd contributions: %w", err)
}
ytdWithdraw, err := s.txRepo.SumByAccountYearType(accountID, year, model.TransactionTypeWithdrawal)
if err != nil {
return nil, fmt.Errorf("failed to sum ytd withdrawals: %w", err)
}
lifeContrib, err := s.txRepo.SumLifetimeByAccountType(accountID, model.TransactionTypeDeposit)
if err != nil {
return nil, fmt.Errorf("failed to sum lifetime contributions: %w", err)
}
lifeWithdraw, err := s.txRepo.SumLifetimeByAccountType(accountID, model.TransactionTypeWithdrawal)
if err != nil {
return nil, fmt.Errorf("failed to sum lifetime withdrawals: %w", err)
}
summary := &model.InvestmentAccountSummary{
Account: account,
Year: year,
YTDContributions: ytdContrib,
YTDWithdrawals: ytdWithdraw,
NetContributions: lifeContrib.Sub(lifeWithdraw),
}
room, err := s.roomRepo.ByAccountAndYear(accountID, year)
if err == nil {
amt := room.RoomAmount
summary.RoomAmount = &amt
rem := amt.Sub(ytdContrib)
summary.RoomRemaining = &rem
} else if err != repository.ErrContributionRoomNotFound {
return nil, fmt.Errorf("failed to load contribution room: %w", err)
}
positions, err := s.HoldingPositions(accountID)
if err != nil {
return nil, err
}
summary.HoldingCount = len(positions)
for _, p := range positions {
summary.TotalCostBasis = summary.TotalCostBasis.Add(p.CostBasis)
}
return summary, nil
}
// ---------- Holdings ----------
func (s *InvestmentService) CreateHolding(accountID, symbol, displayName string) (*model.InvestmentHolding, error) {
if accountID == "" {
return nil, fmt.Errorf("account id is required")
}
if symbol == "" {
return nil, fmt.Errorf("symbol is required")
}
if displayName == "" {
displayName = symbol
}
account, err := s.accountRepo.ByID(accountID)
if err != nil {
return nil, fmt.Errorf("failed to load account: %w", err)
}
if !account.IsInvestment {
return nil, fmt.Errorf("account is not an investment account")
}
now := time.Now()
holding := &model.InvestmentHolding{
ID: uuid.NewString(),
AccountID: accountID,
Symbol: symbol,
DisplayName: displayName,
CreatedAt: now,
UpdatedAt: now,
}
if err := s.holdingRepo.Create(holding); err != nil {
return nil, fmt.Errorf("failed to create holding: %w", err)
}
return holding, nil
}
func (s *InvestmentService) GetHolding(id string) (*model.InvestmentHolding, error) {
h, err := s.holdingRepo.ByID(id)
if err != nil {
return nil, fmt.Errorf("failed to load holding: %w", err)
}
return h, nil
}
func (s *InvestmentService) UpdateHolding(id, symbol, displayName string) error {
if symbol == "" {
return fmt.Errorf("symbol is required")
}
if displayName == "" {
displayName = symbol
}
if err := s.holdingRepo.Update(id, symbol, displayName); err != nil {
return fmt.Errorf("failed to update holding: %w", err)
}
return nil
}
func (s *InvestmentService) DeleteHolding(id string) error {
return s.holdingRepo.Delete(id)
}
func (s *InvestmentService) ListHoldings(accountID string) ([]*model.InvestmentHolding, error) {
return s.holdingRepo.ByAccountID(accountID)
}
// ---------- Trades ----------
type RecordTradeInput struct {
HoldingID string
Type model.InvestmentTradeType
Quantity decimal.Decimal
PricePerUnit decimal.Decimal
Fees *decimal.Decimal
OccurredAt time.Time
Notes *string
}
func (s *InvestmentService) RecordTrade(input RecordTradeInput) (*model.InvestmentTrade, error) {
if input.HoldingID == "" {
return nil, fmt.Errorf("holding id is required")
}
if !model.IsValidInvestmentTradeType(string(input.Type)) {
return nil, fmt.Errorf("invalid trade type: %s", input.Type)
}
if !input.Quantity.IsPositive() {
return nil, fmt.Errorf("quantity must be greater than zero")
}
if input.PricePerUnit.IsNegative() {
return nil, fmt.Errorf("price per unit cannot be negative")
}
if input.OccurredAt.IsZero() {
input.OccurredAt = time.Now()
}
trade := &model.InvestmentTrade{
ID: uuid.NewString(),
HoldingID: input.HoldingID,
Type: input.Type,
Quantity: input.Quantity,
PricePerUnit: input.PricePerUnit,
Fees: input.Fees,
OccurredAt: input.OccurredAt,
Notes: input.Notes,
CreatedAt: time.Now(),
}
if err := s.tradeRepo.Create(trade); err != nil {
return nil, fmt.Errorf("failed to record trade: %w", err)
}
return trade, nil
}
func (s *InvestmentService) UpdateTrade(id string, qty, price decimal.Decimal, fees *decimal.Decimal, occurredAt time.Time, notes *string) error {
if !qty.IsPositive() {
return fmt.Errorf("quantity must be greater than zero")
}
if price.IsNegative() {
return fmt.Errorf("price per unit cannot be negative")
}
return s.tradeRepo.Update(id, qty, price, fees, occurredAt, notes)
}
func (s *InvestmentService) DeleteTrade(id string) error {
return s.tradeRepo.Delete(id)
}
func (s *InvestmentService) GetTrade(id string) (*model.InvestmentTrade, error) {
return s.tradeRepo.ByID(id)
}
func (s *InvestmentService) ListTrades(holdingID string) ([]*model.InvestmentTrade, error) {
return s.tradeRepo.ByHoldingID(holdingID)
}
// HoldingPositions returns the derived position for every holding in the
// account. Positions are computed by replaying each trade in chronological
// order, maintaining a running weighted-average cost basis. Each sell reduces
// the remaining quantity at the current avg cost; realized P/L accumulates on
// each sell as (sell.price avg cost) × qty fees.
func (s *InvestmentService) HoldingPositions(accountID string) ([]model.HoldingPosition, error) {
holdings, err := s.holdingRepo.ByAccountID(accountID)
if err != nil {
return nil, fmt.Errorf("failed to load holdings: %w", err)
}
out := make([]model.HoldingPosition, 0, len(holdings))
for _, h := range holdings {
pos, err := s.holdingPosition(*h)
if err != nil {
return nil, err
}
out = append(out, pos)
}
return out, nil
}
func (s *InvestmentService) HoldingPosition(holdingID string) (*model.HoldingPosition, error) {
h, err := s.holdingRepo.ByID(holdingID)
if err != nil {
return nil, fmt.Errorf("failed to load holding: %w", err)
}
pos, err := s.holdingPosition(*h)
if err != nil {
return nil, err
}
return &pos, nil
}
func (s *InvestmentService) holdingPosition(h model.InvestmentHolding) (model.HoldingPosition, error) {
trades, err := s.tradeRepo.ByHoldingID(h.ID)
if err != nil {
return model.HoldingPosition{}, fmt.Errorf("failed to load trades: %w", err)
}
pos := model.HoldingPosition{Holding: h}
qty := decimal.Zero
avgCost := decimal.Zero
for _, t := range trades {
fees := decimal.Zero
if t.Fees != nil {
fees = *t.Fees
}
pos.TotalFees = pos.TotalFees.Add(fees)
switch t.Type {
case model.InvestmentTradeTypeBuy:
newQty := qty.Add(t.Quantity)
if newQty.IsPositive() {
// weighted average including fees in cost basis
newCost := qty.Mul(avgCost).Add(t.Quantity.Mul(t.PricePerUnit)).Add(fees)
avgCost = newCost.Div(newQty)
}
qty = newQty
pos.TotalBuyQty = pos.TotalBuyQty.Add(t.Quantity)
price := t.PricePerUnit
pos.LastBuyPrice = &price
case model.InvestmentTradeTypeSell:
realized := t.PricePerUnit.Sub(avgCost).Mul(t.Quantity).Sub(fees)
pos.RealizedPL = pos.RealizedPL.Add(realized)
qty = qty.Sub(t.Quantity)
pos.TotalSellQty = pos.TotalSellQty.Add(t.Quantity)
price := t.PricePerUnit
pos.LastSellPrice = &price
if !qty.IsPositive() {
qty = decimal.Zero
avgCost = decimal.Zero
}
}
}
pos.Quantity = qty
pos.AvgCost = avgCost
pos.CostBasis = qty.Mul(avgCost)
return pos, nil
}