feat: disable registration

This commit is contained in:
juancwu 2026-05-17 14:30:59 +00:00
commit 39330ce821
17 changed files with 179 additions and 132 deletions

View file

@ -12,6 +12,9 @@ JWT_SECRET=
# Go duration format
JWT_EXPIRY=168h
# Set to true to block new account creation via magic link. Existing users can still log in.
DISABLE_REGISTRATION=false
MAILER_SMTP_HOST=
MAILER_SMTP_PORT=
MAILER_IMAP_HOST=

View file

@ -11,20 +11,20 @@ import (
)
type App struct {
Cfg *config.Config
DB *sqlx.DB
UserService *service.UserService
AuthService *service.AuthService
EmailService *service.EmailService
SpaceService *service.SpaceService
AccountService *service.AccountService
AllocationService *service.AllocationService
Cfg *config.Config
DB *sqlx.DB
UserService *service.UserService
AuthService *service.AuthService
EmailService *service.EmailService
SpaceService *service.SpaceService
AccountService *service.AccountService
AllocationService *service.AllocationService
TransactionService *service.TransactionService
RecurringEventService *service.RecurringEventService
InviteService *service.InviteService
AuditLogService *service.SpaceAuditLogService
TxAuditLogService *service.TransactionAuditLogService
AccountActivitySvc *service.AccountActivityService
AuditLogService *service.SpaceAuditLogService
TxAuditLogService *service.TransactionAuditLogService
AccountActivitySvc *service.AccountActivityService
}
func New(cfg *config.Config) (*App, error) {
@ -85,25 +85,26 @@ func New(cfg *config.Config) (*App, error) {
cfg.JWTExpiry,
cfg.TokenMagicLinkExpiry,
cfg.IsProduction(),
cfg.DisableRegistration,
)
inviteService := service.NewInviteService(invitationRepository, spaceRepository, userRepository, emailService, auditLogService)
recurringEventService := service.NewRecurringEventService(recurringEventRepository, transactionService, accountService)
return &App{
Cfg: cfg,
DB: database,
UserService: userService,
AuthService: authService,
EmailService: emailService,
SpaceService: spaceService,
AccountService: accountService,
AllocationService: allocationService,
Cfg: cfg,
DB: database,
UserService: userService,
AuthService: authService,
EmailService: emailService,
SpaceService: spaceService,
AccountService: accountService,
AllocationService: allocationService,
TransactionService: transactionService,
RecurringEventService: recurringEventService,
InviteService: inviteService,
AuditLogService: auditLogService,
TxAuditLogService: txAuditLogService,
AccountActivitySvc: accountActivityService,
AuditLogService: auditLogService,
TxAuditLogService: txAuditLogService,
AccountActivitySvc: accountActivityService,
}, nil
}

View file

@ -23,6 +23,8 @@ type Config struct {
JWTExpiry time.Duration
TokenMagicLinkExpiry time.Duration
DisableRegistration bool
MailerSMTPHost string
MailerSMTPPort int
MailerIMAPHost string
@ -58,6 +60,8 @@ func Load(version string) *Config {
JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default
TokenMagicLinkExpiry: envDuration("TOKEN_MAGIC_LINK_EXPIRY", 10*time.Minute),
DisableRegistration: envBool("DISABLE_REGISTRATION", false),
MailerSMTPHost: envString("MAILER_SMTP_HOST", ""),
MailerSMTPPort: envInt("MAILER_SMTP_PORT", 587),
MailerIMAPHost: envString("MAILER_IMAP_HOST", ""),
@ -120,6 +124,19 @@ func envInt(key string, def int) int {
return int(i)
}
func envBool(key string, def bool) bool {
value, ok := os.LookupEnv(key)
if !ok || value == "" {
return def
}
b, err := strconv.ParseBool(value)
if err != nil {
slog.Warn("config invalid bool, using default", "key", key, "value", value, "default", def)
return def
}
return b
}
func envDuration(key string, def time.Duration) time.Duration {
value, ok := os.LookupEnv(key)
if !ok || value == "" {

View file

@ -191,8 +191,8 @@ func (h *allocationHandler) renderSectionWithCreateError(w http.ResponseWriter,
}
ui.Render(w, r, blocks.AllocationsSection(blocks.AllocationsSectionProps{
SpaceID: spaceID, AccountID: accountID, Summary: summary,
CreateForm: &state,
ShowCreateForm: true,
CreateForm: &state,
ShowCreateForm: true,
}))
}

View file

@ -88,6 +88,23 @@ func (h *authHandler) SendMagicLink(w http.ResponseWriter, r *http.Request) {
err = h.authService.SendMagicLink(email)
if err != nil {
slog.Warn("magic link send failed", "error", err, "email", email)
if errors.Is(err, service.ErrRegistrationDisabled) {
msg := "Registration is disabled. Please contact an administrator if you need an account."
if r.URL.Query().Get("resend") == "true" {
ui.RenderToast(w, r, toast.Toast(toast.Props{
Title: "Magic link not sent",
Description: msg,
Variant: toast.VariantError,
Icon: true,
Dismissible: true,
Duration: 5000,
}))
return
}
ui.Render(w, r, pages.Auth(msg))
return
}
}
if r.URL.Query().Get("resend") == "true" {

View file

@ -24,7 +24,7 @@ func newTestAuthHandler(dbi testutil.DBInfo) *authHandler {
spaceSvc := service.NewSpaceService(spaceRepo)
accountSvc := service.NewAccountService(accountRepo)
emailSvc := service.NewEmailService(nil, "test@example.com", "http://localhost:9999", "Budgit Test", false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false, false)
inviteSvc := service.NewInviteService(inviteRepo, spaceRepo, userRepo, emailSvc, nil)
return NewAuthHandler(authSvc, inviteSvc, spaceSvc)
}

View file

@ -78,22 +78,22 @@ func (h *recurringEventHandler) CreatePage(w http.ResponseWriter, r *http.Reques
now := time.Now()
formProps := forms.RecurringEventFormProps{
SpaceID: spaceID,
Action: routeurl.URL("action.app.spaces.space.recurring.create", "spaceID", spaceID),
CancelHref: routeurl.URL("page.app.spaces.space.recurring", "spaceID", spaceID),
SubmitLabel: "Create",
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Kind: string(model.RecurringEventKindBill),
Frequency: string(model.RecurringFrequencyMonthly),
IntervalCount: "1",
FireTime: "09:00",
SpaceID: spaceID,
Action: routeurl.URL("action.app.spaces.space.recurring.create", "spaceID", spaceID),
CancelHref: routeurl.URL("page.app.spaces.space.recurring", "spaceID", spaceID),
SubmitLabel: "Create",
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Kind: string(model.RecurringEventKindBill),
Frequency: string(model.RecurringFrequencyMonthly),
IntervalCount: "1",
FireTime: "09:00",
Timezone: "UTC",
StartDate: now.Format("2006-01-02"),
BusinessDaysOnly: false,
DayOfMonth: strconv.Itoa(now.Day()),
DayOfWeek: strconv.Itoa(int(now.Weekday())),
MonthOfYear: strconv.Itoa(int(now.Month())),
DayOfMonth: strconv.Itoa(now.Day()),
DayOfWeek: strconv.Itoa(int(now.Weekday())),
MonthOfYear: strconv.Itoa(int(now.Month())),
}
ui.Render(w, r, pages.SpaceCreateRecurringEventPage(pages.SpaceCreateRecurringEventPageProps{
@ -126,19 +126,19 @@ func (h *recurringEventHandler) EditPage(w http.ResponseWriter, r *http.Request)
}
formProps := forms.RecurringEventFormProps{
SpaceID: spaceID,
Action: routeurl.URL("action.app.spaces.space.recurring.event.edit", "spaceID", spaceID, "eventID", eventID),
CancelHref: routeurl.URL("page.app.spaces.space.recurring", "spaceID", spaceID),
SubmitLabel: "Save",
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Title: ev.Title,
Kind: string(ev.Kind),
SourceAccountID: ev.SourceAccountID,
Amount: ev.Amount.StringFixedBank(2),
Frequency: string(ev.Frequency),
IntervalCount: strconv.Itoa(ev.IntervalCount),
FireTime: formatTimeOfDay(ev.FireHour, ev.FireMinute),
SpaceID: spaceID,
Action: routeurl.URL("action.app.spaces.space.recurring.event.edit", "spaceID", spaceID, "eventID", eventID),
CancelHref: routeurl.URL("page.app.spaces.space.recurring", "spaceID", spaceID),
SubmitLabel: "Save",
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Title: ev.Title,
Kind: string(ev.Kind),
SourceAccountID: ev.SourceAccountID,
Amount: ev.Amount.StringFixedBank(2),
Frequency: string(ev.Frequency),
IntervalCount: strconv.Itoa(ev.IntervalCount),
FireTime: formatTimeOfDay(ev.FireHour, ev.FireMinute),
Timezone: ev.Timezone,
StartDate: ev.NextRunAt.In(mustLoc(ev.Timezone)).Format("2006-01-02"),
BusinessDaysOnly: ev.BusinessDaysOnly,
@ -215,19 +215,19 @@ func (h *recurringEventHandler) HandleEdit(w http.ResponseWriter, r *http.Reques
}
if _, err := h.recurringService.Update(service.UpdateRecurringEventInput{
ID: eventID,
Kind: parsed.Kind,
SourceAccountID: parsed.SourceAccountID,
Title: parsed.Title,
Amount: parsed.Amount,
Description: parsed.Description,
Frequency: parsed.Frequency,
IntervalCount: parsed.IntervalCount,
DayOfWeek: parsed.DayOfWeek,
DayOfMonth: parsed.DayOfMonth,
MonthOfYear: parsed.MonthOfYear,
FireHour: parsed.FireHour,
FireMinute: parsed.FireMinute,
ID: eventID,
Kind: parsed.Kind,
SourceAccountID: parsed.SourceAccountID,
Title: parsed.Title,
Amount: parsed.Amount,
Description: parsed.Description,
Frequency: parsed.Frequency,
IntervalCount: parsed.IntervalCount,
DayOfWeek: parsed.DayOfWeek,
DayOfMonth: parsed.DayOfMonth,
MonthOfYear: parsed.MonthOfYear,
FireHour: parsed.FireHour,
FireMinute: parsed.FireMinute,
Timezone: parsed.Timezone,
BusinessDaysOnly: parsed.BusinessDaysOnly,
StartDate: parsed.StartDate,
@ -305,19 +305,19 @@ func (h *recurringEventHandler) parseForm(r *http.Request, spaceID string) (serv
businessDaysOnly := r.FormValue("business_days_only") != ""
props := forms.RecurringEventFormProps{
SpaceID: spaceID,
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Title: title,
Kind: kind,
SourceAccountID: sourceID,
Amount: amountStr,
Description: descriptionStr,
Frequency: frequency,
IntervalCount: intervalStr,
DayOfWeek: dowStr,
DayOfMonth: domStr,
MonthOfYear: moyStr,
SpaceID: spaceID,
Accounts: accounts,
Timezones: timezone.CommonTimezones(),
Title: title,
Kind: kind,
SourceAccountID: sourceID,
Amount: amountStr,
Description: descriptionStr,
Frequency: frequency,
IntervalCount: intervalStr,
DayOfWeek: dowStr,
DayOfMonth: domStr,
MonthOfYear: moyStr,
FireTime: fireTime,
Timezone: tz,
StartDate: startDateStr,

View file

@ -21,7 +21,7 @@ func newTestSettingsHandler(dbi testutil.DBInfo) (*settingsHandler, *service.Aut
spaceSvc := service.NewSpaceService(spaceRepo)
accountSvc := service.NewAccountService(accountRepo)
emailSvc := service.NewEmailService(nil, "test@example.com", "http://localhost:9999", "Budgit Test", false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false, false)
userSvc := service.NewUserService(userRepo)
return NewSettingsHandler(authSvc, userSvc), authSvc
}

View file

@ -330,15 +330,15 @@ func (h *spaceHandler) SpaceAccountPage(w http.ResponseWriter, r *http.Request)
}
ui.Render(w, r, pages.SpaceAccountPage(pages.SpaceAccountPageProps{
SpaceID: spaceID,
SpaceName: space.Name,
AccountID: accountID,
AccountName: account.Name,
AccountBalance: account.Balance,
AccountCurrency: account.Currency,
RecentTransactions: recent,
SpaceID: spaceID,
SpaceName: space.Name,
AccountID: accountID,
AccountName: account.Name,
AccountBalance: account.Balance,
AccountCurrency: account.Currency,
RecentTransactions: recent,
NonEditableTransactionIDs: h.nonEditableTransactionIDs(recent),
AllocationSummary: allocSummary,
AllocationSummary: allocSummary,
}))
}
@ -396,16 +396,16 @@ func (h *spaceHandler) SpaceAccountTransactionsPage(w http.ResponseWriter, r *ht
}
ui.Render(w, r, pages.SpaceAccountTransactionsPage(pages.SpaceAccountTransactionsPageProps{
SpaceID: spaceID,
SpaceName: space.Name,
AccountID: accountID,
AccountName: account.Name,
Transactions: txns,
SpaceID: spaceID,
SpaceName: space.Name,
AccountID: accountID,
AccountName: account.Name,
Transactions: txns,
NonEditableTransactionIDs: h.nonEditableTransactionIDs(txns),
CurrentPage: page,
TotalPages: totalPages,
TotalCount: total,
PerPage: perPage,
CurrentPage: page,
TotalPages: totalPages,
TotalCount: total,
PerPage: perPage,
}))
}

View file

@ -5,19 +5,19 @@ import "time"
type SpaceAuditAction string
const (
SpaceAuditActionRenamed SpaceAuditAction = "space.renamed"
SpaceAuditActionDeleted SpaceAuditAction = "space.deleted"
SpaceAuditActionMemberInvited SpaceAuditAction = "member.invited"
SpaceAuditActionMemberJoined SpaceAuditAction = "member.joined"
SpaceAuditActionMemberRemoved SpaceAuditAction = "member.removed"
SpaceAuditActionInviteCancelled SpaceAuditAction = "invite.cancelled"
SpaceAuditActionAccountCreated SpaceAuditAction = "account.created"
SpaceAuditActionAccountRenamed SpaceAuditAction = "account.renamed"
SpaceAuditActionAccountDeleted SpaceAuditAction = "account.deleted"
SpaceAuditActionAccountCurrencyChanged SpaceAuditAction = "account.currency_changed"
SpaceAuditActionAllocationCreated SpaceAuditAction = "allocation.created"
SpaceAuditActionAllocationUpdated SpaceAuditAction = "allocation.updated"
SpaceAuditActionAllocationDeleted SpaceAuditAction = "allocation.deleted"
SpaceAuditActionRenamed SpaceAuditAction = "space.renamed"
SpaceAuditActionDeleted SpaceAuditAction = "space.deleted"
SpaceAuditActionMemberInvited SpaceAuditAction = "member.invited"
SpaceAuditActionMemberJoined SpaceAuditAction = "member.joined"
SpaceAuditActionMemberRemoved SpaceAuditAction = "member.removed"
SpaceAuditActionInviteCancelled SpaceAuditAction = "invite.cancelled"
SpaceAuditActionAccountCreated SpaceAuditAction = "account.created"
SpaceAuditActionAccountRenamed SpaceAuditAction = "account.renamed"
SpaceAuditActionAccountDeleted SpaceAuditAction = "account.deleted"
SpaceAuditActionAccountCurrencyChanged SpaceAuditAction = "account.currency_changed"
SpaceAuditActionAllocationCreated SpaceAuditAction = "allocation.created"
SpaceAuditActionAllocationUpdated SpaceAuditAction = "allocation.updated"
SpaceAuditActionAllocationDeleted SpaceAuditAction = "allocation.deleted"
)
type SpaceAuditLog struct {

View file

@ -26,7 +26,7 @@ func newTestApp(dbi testutil.DBInfo) *app.App {
spaceSvc := service.NewSpaceService(spaceRepo)
accountSvc := service.NewAccountService(accountRepo)
emailSvc := service.NewEmailService(nil, "test@example.com", "http://localhost:9999", "Budgit Test", false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false)
authSvc := service.NewAuthService(emailSvc, userRepo, tokenRepo, spaceSvc, accountSvc, cfg.JWTSecret, cfg.JWTExpiry, cfg.TokenMagicLinkExpiry, false, false)
userSvc := service.NewUserService(userRepo)
inviteSvc := service.NewInviteService(inviteRepo, spaceRepo, userRepo, emailSvc, nil)

View file

@ -27,7 +27,7 @@ func (s *stubSpaceAuditRepo) ListBySpace(_ string, limit, _ int) ([]*model.Space
}
return firstN(s.listSpace, limit), nil
}
func (s *stubSpaceAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
func (s *stubSpaceAuditRepo) CountBySpace(string) (int, error) { return s.countSpace, s.err }
func (s *stubSpaceAuditRepo) ListAccountEvents(_ string, limit, _ int) ([]*model.SpaceAuditLogWithActor, error) {
if s.err != nil {
return nil, s.err
@ -99,9 +99,9 @@ func TestAccountActivityService_List_MergesAndSortsByTimestamp(t *testing.T) {
}
txRepo := &stubTxAuditRepo{
listAccount: []*model.TransactionAuditLogWithActor{
txLog(model.TransactionAuditActionEdited, now), // newest overall
txLog(model.TransactionAuditActionEdited, now), // newest overall
txLog(model.TransactionAuditActionCreated, now.Add(-5*time.Minute)),
txLog(model.TransactionAuditActionDeleted, now.Add(-15*time.Minute)), // oldest overall
txLog(model.TransactionAuditActionDeleted, now.Add(-15*time.Minute)), // oldest overall
},
countAccount: 3,
}

View file

@ -20,15 +20,16 @@ import (
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrNoPassword = errors.New("account uses passwordless login. Use magic link")
ErrPasswordsDoNotMatch = errors.New("passwords do not match")
ErrEmailAlreadyExists = errors.New("email already exists")
ErrWeakPassword = errors.New("password must be at least 12 characters")
ErrCommonPassword = errors.New("password is too common, please choose a stronger one")
ErrEmailNotVerified = errors.New("email not verified")
ErrInvalidEmail = errors.New("invalid email address")
ErrNameRequired = errors.New("name is required")
ErrInvalidCredentials = errors.New("invalid email or password")
ErrNoPassword = errors.New("account uses passwordless login. Use magic link")
ErrPasswordsDoNotMatch = errors.New("passwords do not match")
ErrEmailAlreadyExists = errors.New("email already exists")
ErrWeakPassword = errors.New("password must be at least 12 characters")
ErrCommonPassword = errors.New("password is too common, please choose a stronger one")
ErrEmailNotVerified = errors.New("email not verified")
ErrInvalidEmail = errors.New("invalid email address")
ErrNameRequired = errors.New("name is required")
ErrRegistrationDisabled = errors.New("registration is disabled")
)
type AuthService struct {
@ -41,6 +42,7 @@ type AuthService struct {
jwtExpiry time.Duration
tokenMagicLinkExpiry time.Duration
isProduction bool
disableRegistration bool
}
func NewAuthService(
@ -53,6 +55,7 @@ func NewAuthService(
jwtExpiry time.Duration,
tokenMagicLinkExpiry time.Duration,
isProduction bool,
disableRegistration bool,
) *AuthService {
return &AuthService{
emailService: emailService,
@ -64,6 +67,7 @@ func NewAuthService(
jwtExpiry: jwtExpiry,
tokenMagicLinkExpiry: tokenMagicLinkExpiry,
isProduction: isProduction,
disableRegistration: disableRegistration,
}
}
@ -235,6 +239,10 @@ func (s *AuthService) SendMagicLink(email string) error {
if err != nil {
// User doesn't exist - create a new passwordless account
if errors.Is(err, repository.ErrUserNotFound) {
if s.disableRegistration {
slog.Info("registration disabled, refusing to create new user", "email", email)
return ErrRegistrationDisabled
}
now := time.Now()
user = &model.User{
ID: uuid.NewString(),

View file

@ -30,6 +30,7 @@ func newTestAuthService(dbi testutil.DBInfo) *AuthService {
cfg.JWTExpiry,
cfg.TokenMagicLinkExpiry,
false,
false,
)
}

View file

@ -210,9 +210,9 @@ func TestFirstFireOnOrAfter_WeeklyShiftsToTargetDayOfWeek(t *testing.T) {
func TestAddMonths(t *testing.T) {
tests := []struct {
y int
m time.Month
n int
y int
m time.Month
n int
wy int
wm time.Month
}{

View file

@ -11,7 +11,7 @@ import (
)
type fakeSpaceAuditRepo struct {
created []*model.SpaceAuditLog
created []*model.SpaceAuditLog
failNext error
}

View file

@ -16,11 +16,11 @@ import (
// txnFixture builds a fully wired TransactionService against a real DB along with
// the helper repos the tests need to inspect post-state.
type txnFixture struct {
svc *TransactionService
txAudit repository.TransactionAuditLogRepository
accounts repository.AccountRepository
user *model.User
account *model.Account
svc *TransactionService
txAudit repository.TransactionAuditLogRepository
accounts repository.AccountRepository
user *model.User
account *model.Account
}
func newTxnFixture(t *testing.T, dbi testutil.DBInfo) *txnFixture {
@ -149,8 +149,8 @@ func TestTransactionService_UpdateDeposit_RebalancesAndDiffs(t *testing.T) {
assert.Equal(t, model.TransactionAuditActionEdited, logs[0].Action)
var meta struct {
AccountID string `json:"account_id"`
Changes map[string]map[string]any `json:"changes"`
AccountID string `json:"account_id"`
Changes map[string]map[string]any `json:"changes"`
}
require.NoError(t, json.Unmarshal(logs[0].Metadata, &meta))
assert.Equal(t, f.account.ID, meta.AccountID)