feat: investment accounts
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m50s
All checks were successful
Deploy / build-and-deploy (push) Successful in 1m50s
This commit is contained in:
parent
f444a074bc
commit
7c24a8302d
25 changed files with 2205 additions and 56 deletions
355
internal/service/investment.go
Normal file
355
internal/service/investment.go
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/budgit/internal/model"
|
||||
"git.juancwu.dev/juancwu/budgit/internal/repository"
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
// InvestmentService handles contribution rooms, holdings, trades, and the
|
||||
// summary view for investment-flagged accounts. Cash movement (contributions
|
||||
// and withdrawals) still goes through TransactionService / TransferService;
|
||||
// this service reads from those tables but never mutates them directly.
|
||||
type InvestmentService struct {
|
||||
accountRepo repository.AccountRepository
|
||||
roomRepo repository.InvestmentContributionRoomRepository
|
||||
holdingRepo repository.InvestmentHoldingRepository
|
||||
tradeRepo repository.InvestmentTradeRepository
|
||||
txRepo repository.TransactionRepository
|
||||
}
|
||||
|
||||
func NewInvestmentService(
|
||||
accountRepo repository.AccountRepository,
|
||||
roomRepo repository.InvestmentContributionRoomRepository,
|
||||
holdingRepo repository.InvestmentHoldingRepository,
|
||||
tradeRepo repository.InvestmentTradeRepository,
|
||||
txRepo repository.TransactionRepository,
|
||||
) *InvestmentService {
|
||||
return &InvestmentService{
|
||||
accountRepo: accountRepo,
|
||||
roomRepo: roomRepo,
|
||||
holdingRepo: holdingRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
txRepo: txRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Contribution room ----------
|
||||
|
||||
func (s *InvestmentService) SetContributionRoom(accountID string, year int, room decimal.Decimal) error {
|
||||
if accountID == "" {
|
||||
return fmt.Errorf("account id is required")
|
||||
}
|
||||
if year < 1900 || year > 9999 {
|
||||
return fmt.Errorf("year out of range")
|
||||
}
|
||||
if room.IsNegative() {
|
||||
return fmt.Errorf("contribution room cannot be negative")
|
||||
}
|
||||
account, err := s.accountRepo.ByID(accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load account: %w", err)
|
||||
}
|
||||
if !account.IsInvestment {
|
||||
return fmt.Errorf("account is not an investment account")
|
||||
}
|
||||
now := time.Now()
|
||||
return s.roomRepo.Upsert(&model.InvestmentContributionRoom{
|
||||
AccountID: accountID,
|
||||
Year: year,
|
||||
RoomAmount: room,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InvestmentService) GetContributionRoom(accountID string, year int) (*model.InvestmentContributionRoom, error) {
|
||||
room, err := s.roomRepo.ByAccountAndYear(accountID, year)
|
||||
if err != nil {
|
||||
if err == repository.ErrContributionRoomNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to load contribution room: %w", err)
|
||||
}
|
||||
return room, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) ListContributionRooms(accountID string) ([]*model.InvestmentContributionRoom, error) {
|
||||
rooms, err := s.roomRepo.ByAccountID(accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list contribution rooms: %w", err)
|
||||
}
|
||||
return rooms, nil
|
||||
}
|
||||
|
||||
// ---------- Summary ----------
|
||||
|
||||
// SummarizeAccount produces the rollup view for an investment account in the
|
||||
// given calendar year: contribution room, YTD cash flow, lifetime net
|
||||
// contributions, and total cost basis across all holdings.
|
||||
func (s *InvestmentService) SummarizeAccount(accountID string, year int) (*model.InvestmentAccountSummary, error) {
|
||||
account, err := s.accountRepo.ByID(accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load account: %w", err)
|
||||
}
|
||||
if !account.IsInvestment {
|
||||
return nil, fmt.Errorf("account is not an investment account")
|
||||
}
|
||||
|
||||
ytdContrib, err := s.txRepo.SumByAccountYearType(accountID, year, model.TransactionTypeDeposit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sum ytd contributions: %w", err)
|
||||
}
|
||||
ytdWithdraw, err := s.txRepo.SumByAccountYearType(accountID, year, model.TransactionTypeWithdrawal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sum ytd withdrawals: %w", err)
|
||||
}
|
||||
lifeContrib, err := s.txRepo.SumLifetimeByAccountType(accountID, model.TransactionTypeDeposit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sum lifetime contributions: %w", err)
|
||||
}
|
||||
lifeWithdraw, err := s.txRepo.SumLifetimeByAccountType(accountID, model.TransactionTypeWithdrawal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sum lifetime withdrawals: %w", err)
|
||||
}
|
||||
|
||||
summary := &model.InvestmentAccountSummary{
|
||||
Account: account,
|
||||
Year: year,
|
||||
YTDContributions: ytdContrib,
|
||||
YTDWithdrawals: ytdWithdraw,
|
||||
NetContributions: lifeContrib.Sub(lifeWithdraw),
|
||||
}
|
||||
|
||||
room, err := s.roomRepo.ByAccountAndYear(accountID, year)
|
||||
if err == nil {
|
||||
amt := room.RoomAmount
|
||||
summary.RoomAmount = &amt
|
||||
rem := amt.Sub(ytdContrib)
|
||||
summary.RoomRemaining = &rem
|
||||
} else if err != repository.ErrContributionRoomNotFound {
|
||||
return nil, fmt.Errorf("failed to load contribution room: %w", err)
|
||||
}
|
||||
|
||||
positions, err := s.HoldingPositions(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summary.HoldingCount = len(positions)
|
||||
for _, p := range positions {
|
||||
summary.TotalCostBasis = summary.TotalCostBasis.Add(p.CostBasis)
|
||||
}
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// ---------- Holdings ----------
|
||||
|
||||
func (s *InvestmentService) CreateHolding(accountID, symbol, displayName string) (*model.InvestmentHolding, error) {
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("account id is required")
|
||||
}
|
||||
if symbol == "" {
|
||||
return nil, fmt.Errorf("symbol is required")
|
||||
}
|
||||
if displayName == "" {
|
||||
displayName = symbol
|
||||
}
|
||||
account, err := s.accountRepo.ByID(accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load account: %w", err)
|
||||
}
|
||||
if !account.IsInvestment {
|
||||
return nil, fmt.Errorf("account is not an investment account")
|
||||
}
|
||||
now := time.Now()
|
||||
holding := &model.InvestmentHolding{
|
||||
ID: uuid.NewString(),
|
||||
AccountID: accountID,
|
||||
Symbol: symbol,
|
||||
DisplayName: displayName,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := s.holdingRepo.Create(holding); err != nil {
|
||||
return nil, fmt.Errorf("failed to create holding: %w", err)
|
||||
}
|
||||
return holding, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) GetHolding(id string) (*model.InvestmentHolding, error) {
|
||||
h, err := s.holdingRepo.ByID(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load holding: %w", err)
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) UpdateHolding(id, symbol, displayName string) error {
|
||||
if symbol == "" {
|
||||
return fmt.Errorf("symbol is required")
|
||||
}
|
||||
if displayName == "" {
|
||||
displayName = symbol
|
||||
}
|
||||
if err := s.holdingRepo.Update(id, symbol, displayName); err != nil {
|
||||
return fmt.Errorf("failed to update holding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) DeleteHolding(id string) error {
|
||||
return s.holdingRepo.Delete(id)
|
||||
}
|
||||
|
||||
func (s *InvestmentService) ListHoldings(accountID string) ([]*model.InvestmentHolding, error) {
|
||||
return s.holdingRepo.ByAccountID(accountID)
|
||||
}
|
||||
|
||||
// ---------- Trades ----------
|
||||
|
||||
type RecordTradeInput struct {
|
||||
HoldingID string
|
||||
Type model.InvestmentTradeType
|
||||
Quantity decimal.Decimal
|
||||
PricePerUnit decimal.Decimal
|
||||
Fees *decimal.Decimal
|
||||
OccurredAt time.Time
|
||||
Notes *string
|
||||
}
|
||||
|
||||
func (s *InvestmentService) RecordTrade(input RecordTradeInput) (*model.InvestmentTrade, error) {
|
||||
if input.HoldingID == "" {
|
||||
return nil, fmt.Errorf("holding id is required")
|
||||
}
|
||||
if !model.IsValidInvestmentTradeType(string(input.Type)) {
|
||||
return nil, fmt.Errorf("invalid trade type: %s", input.Type)
|
||||
}
|
||||
if !input.Quantity.IsPositive() {
|
||||
return nil, fmt.Errorf("quantity must be greater than zero")
|
||||
}
|
||||
if input.PricePerUnit.IsNegative() {
|
||||
return nil, fmt.Errorf("price per unit cannot be negative")
|
||||
}
|
||||
if input.OccurredAt.IsZero() {
|
||||
input.OccurredAt = time.Now()
|
||||
}
|
||||
trade := &model.InvestmentTrade{
|
||||
ID: uuid.NewString(),
|
||||
HoldingID: input.HoldingID,
|
||||
Type: input.Type,
|
||||
Quantity: input.Quantity,
|
||||
PricePerUnit: input.PricePerUnit,
|
||||
Fees: input.Fees,
|
||||
OccurredAt: input.OccurredAt,
|
||||
Notes: input.Notes,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.tradeRepo.Create(trade); err != nil {
|
||||
return nil, fmt.Errorf("failed to record trade: %w", err)
|
||||
}
|
||||
return trade, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) UpdateTrade(id string, qty, price decimal.Decimal, fees *decimal.Decimal, occurredAt time.Time, notes *string) error {
|
||||
if !qty.IsPositive() {
|
||||
return fmt.Errorf("quantity must be greater than zero")
|
||||
}
|
||||
if price.IsNegative() {
|
||||
return fmt.Errorf("price per unit cannot be negative")
|
||||
}
|
||||
return s.tradeRepo.Update(id, qty, price, fees, occurredAt, notes)
|
||||
}
|
||||
|
||||
func (s *InvestmentService) DeleteTrade(id string) error {
|
||||
return s.tradeRepo.Delete(id)
|
||||
}
|
||||
|
||||
func (s *InvestmentService) GetTrade(id string) (*model.InvestmentTrade, error) {
|
||||
return s.tradeRepo.ByID(id)
|
||||
}
|
||||
|
||||
func (s *InvestmentService) ListTrades(holdingID string) ([]*model.InvestmentTrade, error) {
|
||||
return s.tradeRepo.ByHoldingID(holdingID)
|
||||
}
|
||||
|
||||
// HoldingPositions returns the derived position for every holding in the
|
||||
// account. Positions are computed by replaying each trade in chronological
|
||||
// order, maintaining a running weighted-average cost basis. Each sell reduces
|
||||
// the remaining quantity at the current avg cost; realized P/L accumulates on
|
||||
// each sell as (sell.price − avg cost) × qty − fees.
|
||||
func (s *InvestmentService) HoldingPositions(accountID string) ([]model.HoldingPosition, error) {
|
||||
holdings, err := s.holdingRepo.ByAccountID(accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load holdings: %w", err)
|
||||
}
|
||||
out := make([]model.HoldingPosition, 0, len(holdings))
|
||||
for _, h := range holdings {
|
||||
pos, err := s.holdingPosition(*h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, pos)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) HoldingPosition(holdingID string) (*model.HoldingPosition, error) {
|
||||
h, err := s.holdingRepo.ByID(holdingID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load holding: %w", err)
|
||||
}
|
||||
pos, err := s.holdingPosition(*h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pos, nil
|
||||
}
|
||||
|
||||
func (s *InvestmentService) holdingPosition(h model.InvestmentHolding) (model.HoldingPosition, error) {
|
||||
trades, err := s.tradeRepo.ByHoldingID(h.ID)
|
||||
if err != nil {
|
||||
return model.HoldingPosition{}, fmt.Errorf("failed to load trades: %w", err)
|
||||
}
|
||||
pos := model.HoldingPosition{Holding: h}
|
||||
qty := decimal.Zero
|
||||
avgCost := decimal.Zero
|
||||
for _, t := range trades {
|
||||
fees := decimal.Zero
|
||||
if t.Fees != nil {
|
||||
fees = *t.Fees
|
||||
}
|
||||
pos.TotalFees = pos.TotalFees.Add(fees)
|
||||
switch t.Type {
|
||||
case model.InvestmentTradeTypeBuy:
|
||||
newQty := qty.Add(t.Quantity)
|
||||
if newQty.IsPositive() {
|
||||
// weighted average including fees in cost basis
|
||||
newCost := qty.Mul(avgCost).Add(t.Quantity.Mul(t.PricePerUnit)).Add(fees)
|
||||
avgCost = newCost.Div(newQty)
|
||||
}
|
||||
qty = newQty
|
||||
pos.TotalBuyQty = pos.TotalBuyQty.Add(t.Quantity)
|
||||
price := t.PricePerUnit
|
||||
pos.LastBuyPrice = &price
|
||||
case model.InvestmentTradeTypeSell:
|
||||
realized := t.PricePerUnit.Sub(avgCost).Mul(t.Quantity).Sub(fees)
|
||||
pos.RealizedPL = pos.RealizedPL.Add(realized)
|
||||
qty = qty.Sub(t.Quantity)
|
||||
pos.TotalSellQty = pos.TotalSellQty.Add(t.Quantity)
|
||||
price := t.PricePerUnit
|
||||
pos.LastSellPrice = &price
|
||||
if !qty.IsPositive() {
|
||||
qty = decimal.Zero
|
||||
avgCost = decimal.Zero
|
||||
}
|
||||
}
|
||||
}
|
||||
pos.Quantity = qty
|
||||
pos.AvgCost = avgCost
|
||||
pos.CostBasis = qty.Mul(avgCost)
|
||||
return pos, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue