diff --git a/.env.include.example b/.env.include.example index 8ce00b8..6f3bcb2 100644 --- a/.env.include.example +++ b/.env.include.example @@ -19,6 +19,5 @@ MAILER_IMAP_PORT= MAILER_USERNAME= MAILER_PASSWORD= MAILER_EMAIL_FROM= -MAILER_ENVELOPE_FROM= -MAILER_SUPPORT_EMAIL= -MAILER_SUPPORT_ENVELOPE_FROM= + +SUPPORT_EMAIL= diff --git a/internal/app/app.go b/internal/app/app.go index 3c1abcb..1f1ca65 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -11,11 +11,12 @@ import ( ) type App struct { - Cfg *config.Config - DB *sqlx.DB - UserService *service.UserService - AuthService *service.AuthService - EmailService *service.EmailService + Cfg *config.Config + DB *sqlx.DB + UserService *service.UserService + AuthService *service.AuthService + EmailService *service.EmailService + ProfileService *service.ProfileService } func New(cfg *config.Config) (*App, error) { @@ -32,17 +33,36 @@ func New(cfg *config.Config) (*App, error) { emailClient := service.NewEmailClient(cfg.MailerSMTPHost, cfg.MailerSMTPPort, cfg.MailerIMAPHost, cfg.MailerIMAPPort, cfg.MailerUsername, cfg.MailerPassword) userRepository := repository.NewUserRepository(database) + profileRepository := repository.NewProfileRepository(database) + tokenRepository := repository.NewTokenRepository(database) userService := service.NewUserService(userRepository) - authService := service.NewAuthService(userRepository) - emailService := service.NewEmailService(emailClient, cfg.MailerEmailFrom, cfg.MailerEnvelopeFrom, cfg.MailerSupportFrom, cfg.MailerSupportEnvelopeFrom, cfg.AppURL, cfg.AppName, cfg.AppEnv == "development") + emailService := service.NewEmailService( + emailClient, + cfg.MailerEmailFrom, + cfg.AppURL, + cfg.AppName, + cfg.IsProduction(), + ) + authService := service.NewAuthService( + emailService, + userRepository, + profileRepository, + tokenRepository, + cfg.JWTSecret, + cfg.JWTExpiry, + cfg.TokenMagicLinkExpiry, + cfg.IsProduction(), + ) + profileService := service.NewProfileService(profileRepository) return &App{ - Cfg: cfg, - DB: database, - UserService: userService, - AuthService: authService, - EmailService: emailService, + Cfg: cfg, + DB: database, + UserService: userService, + AuthService: authService, + EmailService: emailService, + ProfileService: profileService, }, nil } diff --git a/internal/config/config.go b/internal/config/config.go index 09f9f07..6c31fbe 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,19 +20,19 @@ type Config struct { DBDriver string DBConnection string - JWTSecret string - JWTExpiry time.Duration + JWTSecret string + JWTExpiry time.Duration + TokenMagicLinkExpiry time.Duration - MailerSMTPHost string - MailerSMTPPort int - MailerIMAPHost string - MailerIMAPPort int - MailerUsername string - MailerPassword string - MailerEmailFrom string - MailerEnvelopeFrom string - MailerSupportFrom string - MailerSupportEnvelopeFrom string + MailerSMTPHost string + MailerSMTPPort int + MailerIMAPHost string + MailerIMAPPort int + MailerUsername string + MailerPassword string + MailerEmailFrom string + + SupportEmail string } func Load() *Config { @@ -52,19 +52,19 @@ func Load() *Config { DBDriver: envString("DB_DRIVER", "sqlite"), DBConnection: envString("DB_CONNECTION", "./data/local.db?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)"), - JWTSecret: envRequired("JWT_SECRET"), - JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default + JWTSecret: envRequired("JWT_SECRET"), + JWTExpiry: envDuration("JWT_EXPIRY", 168*time.Hour), // 7 days default + TokenMagicLinkExpiry: envDuration("TOKEN_MAGIC_LINK_EXPIRY", 10*time.Minute), - MailerSMTPHost: envString("MAILER_SMTP_HOST", ""), - MailerSMTPPort: envInt("MAILER_SMTP_PORT", 587), - MailerIMAPHost: envString("MAILER_IMAP_HOST", ""), - MailerIMAPPort: envInt("MAILER_IMAP_PORT", 993), - MailerUsername: envString("MAILER_USERNAME", ""), - MailerPassword: envString("MAILER_PASSWORD", ""), - MailerEmailFrom: envString("MAILER_EMAIL_FROM", ""), - MailerEnvelopeFrom: envString("MAILER_ENVELOPE_FROM", ""), - MailerSupportFrom: envString("MAILER_SUPPORT_EMAIL_FROM", ""), - MailerSupportEnvelopeFrom: envString("MAILER_SUPPORT_ENVELOPE_FROM", ""), + MailerSMTPHost: envString("MAILER_SMTP_HOST", ""), + MailerSMTPPort: envInt("MAILER_SMTP_PORT", 587), + MailerIMAPHost: envString("MAILER_IMAP_HOST", ""), + MailerIMAPPort: envInt("MAILER_IMAP_PORT", 993), + MailerUsername: envString("MAILER_USERNAME", ""), + MailerPassword: envString("MAILER_PASSWORD", ""), + MailerEmailFrom: envString("MAILER_EMAIL_FROM", ""), + + SupportEmail: envString("SUPPORT_EMAIL", ""), } return cfg @@ -85,8 +85,8 @@ func (c *Config) Sanitized() *Config { Port: c.Port, AppTagline: c.AppTagline, - MailerEmailFrom: c.MailerEmailFrom, - MailerEnvelopeFrom: c.MailerEnvelopeFrom, + MailerEmailFrom: c.MailerEmailFrom, + SupportEmail: c.SupportEmail, } } diff --git a/internal/db/migrations/00001_create_users_table.sql b/internal/db/migrations/00001_create_users_table.sql index d679c39..3f25a13 100644 --- a/internal/db/migrations/00001_create_users_table.sql +++ b/internal/db/migrations/00001_create_users_table.sql @@ -1,7 +1,7 @@ -- +goose Up -- +goose StatementBegin CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY NOT NULL, + id TEXT PRIMARY KEY NOT NULL, email TEXT UNIQUE NOT NULL, password_hash TEXT NULL, -- Allow null for passwordless login pending_email TEXT NULL, -- Store new email when changing email diff --git a/internal/db/migrations/00002_create_tokens_table.sql b/internal/db/migrations/00002_create_tokens_table.sql index 96fbaf1..8c1003d 100644 --- a/internal/db/migrations/00002_create_tokens_table.sql +++ b/internal/db/migrations/00002_create_tokens_table.sql @@ -1,8 +1,8 @@ -- +goose Up -- +goose StatementBegin CREATE TABLE IF NOT EXISTS tokens ( - id SERIAL PRIMARY KEY NOT NULL, - user_id INTEGER NOT NULL, + id TEXT PRIMARY KEY NOT NULL, + user_id TEXT NOT NULL, type TEXT NOT NULL, token TEXT UNIQUE NOT NULL, expires_at TIMESTAMP NOT NULL, diff --git a/internal/db/migrations/00003_create_profiles_table.sql b/internal/db/migrations/00003_create_profiles_table.sql index edd5ede..390f6e6 100644 --- a/internal/db/migrations/00003_create_profiles_table.sql +++ b/internal/db/migrations/00003_create_profiles_table.sql @@ -1,8 +1,8 @@ -- +goose Up -- +goose StatementBegin CREATE TABLE IF NOT EXISTS profiles ( - id SERIAL PRIMARY KEY NOT NULL, - user_id INTEGER UNIQUE NOT NULL, + id TEXT PRIMARY KEY NOT NULL, + user_id TEXT UNIQUE NOT NULL, name TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/internal/db/migrations/00004_create_files_table.sql b/internal/db/migrations/00004_create_files_table.sql index ee97969..2ea0a01 100644 --- a/internal/db/migrations/00004_create_files_table.sql +++ b/internal/db/migrations/00004_create_files_table.sql @@ -1,8 +1,8 @@ -- +goose Up -- +goose StatementBegin CREATE TABLE IF NOT EXISTS files ( - id SERIAL PRIMARY KEY NOT NULL, - user_id INTEGER NOT NULL, + id TEXT PRIMARY KEY NOT NULL, + user_id TEXT NOT NULL, owner_type TEXT NOT NULL, owner_id TEXT NOT NULL, type TEXT NOT NULL, diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 949eca4..2d858ea 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -1,17 +1,24 @@ package handler import ( + "log/slog" "net/http" + "strings" + "time" + "git.juancwu.dev/juancwu/budgit/internal/service" "git.juancwu.dev/juancwu/budgit/internal/ui" + "git.juancwu.dev/juancwu/budgit/internal/ui/components/toast" "git.juancwu.dev/juancwu/budgit/internal/ui/pages" + "git.juancwu.dev/juancwu/budgit/internal/validation" ) type authHandler struct { + authService *service.AuthService } -func NewAuthHandler() *authHandler { - return &authHandler{} +func NewAuthHandler(authService *service.AuthService) *authHandler { + return &authHandler{authService: authService} } func (h *authHandler) AuthPage(w http.ResponseWriter, r *http.Request) { @@ -21,3 +28,62 @@ func (h *authHandler) AuthPage(w http.ResponseWriter, r *http.Request) { func (h *authHandler) PasswordPage(w http.ResponseWriter, r *http.Request) { ui.Render(w, r, pages.AuthPassword("")) } + +func (h *authHandler) SendMagicLink(w http.ResponseWriter, r *http.Request) { + email := strings.TrimSpace(r.FormValue("email")) + + if email == "" { + ui.Render(w, r, pages.Auth("Email is required")) + return + } + + err := validation.ValidateEmail(email) + if err != nil { + ui.Render(w, r, pages.Auth("Please provide a valid email address")) + return + } + + err = h.authService.SendMagicLink(email) + if err != nil { + slog.Warn("magic link send failed", "error", err, "email", email) + } + + if r.URL.Query().Get("resend") == "true" { + ui.RenderOOB(w, r, toast.Toast(toast.Props{ + Title: "Magic link sent", + Description: "Check your email for a new magic link", + Variant: toast.VariantSuccess, + Icon: true, + Dismissible: true, + Duration: 5000, + }), "beforeend:#toast-container") + return + } + + ui.Render(w, r, pages.MagicLinkSent(email)) +} + +func (h *authHandler) VerifyMagicLink(w http.ResponseWriter, r *http.Request) { + tokenString := r.PathValue("token") + + user, err := h.authService.VerifyMagicLink(tokenString) + if err != nil { + slog.Warn("magic link verification failed", "error", err, "token", tokenString) + ui.Render(w, r, pages.Auth("Invalid or expired magic link. Please try again.")) + return + } + + jwtToken, err := h.authService.GenerateJWT(user) + if err != nil { + slog.Error("failed to generate JWT", "error", err, "user_id", user.ID) + ui.Render(w, r, pages.Auth("An error occurred. Please try again.")) + return + } + + h.authService.SetJWTCookie(w, jwtToken, time.Now().Add(7*24*time.Hour)) + + // TODO: check for onboarding + + slog.Info("user logged via magic link", "user_id", user.ID, "email", user.Email) + http.Redirect(w, r, "/app/dashboard", http.StatusSeeOther) +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 83c11cb..7c50483 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -10,13 +10,58 @@ import ( // TODO: implement clearing jwt token in auth service // AuthMiddleware checks for JWT token and adds user + profile + subscription to context if valid -func AuthMiddleware(authService *service.AuthService, userService *service.UserService) func(http.Handler) http.Handler { +func AuthMiddleware(authService *service.AuthService, userService *service.UserService, profileService *service.ProfileService) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO: get auth cookie and verify value - // TODO: fetch user information from database if cookie value is valid - // TODO: add user to context if valid - next.ServeHTTP(w, r) + // Get JWT from cookie + cookie, err := r.Cookie("auth_token") + if err != nil { + // No cookie, continue without auth + next.ServeHTTP(w, r) + return + } + + // Verify token + claims, err := authService.VerifyJWT(cookie.Value) + if err != nil { + // Invalid token, clear cookie and continue + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Get user ID from claims + userID, ok := claims["user_id"].(string) + if !ok { + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Fetch user from database + user, err := userService.ByID(userID) + if err != nil { + + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Security: Remove password hash from context + user.PasswordHash = nil + + profile, err := profileService.ByUserID(userID) + if err != nil { + // Profile not found - this shouldn't happen but handle gracefully + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Add user + profile to context + ctx := ctxkeys.WithUser(r.Context(), user) + ctx = ctxkeys.WithProfile(ctx, profile) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } @@ -56,17 +101,17 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc { // Check if user has completed onboarding // Uses profile.Name as indicator (empty = incomplete onboarding) - profile := ctxkeys.Profile(r.Context()) - if profile.Name == "" && r.URL.Path != "/auth/onboarding" { - // User hasn't completed onboarding, redirect to onboarding - if r.Header.Get("HX-Request") == "true" { - w.Header().Set("HX-Redirect", "/auth/onboarding") - w.WriteHeader(http.StatusSeeOther) - return - } - http.Redirect(w, r, "/auth/onboarding", http.StatusSeeOther) - return - } + // profile := ctxkeys.Profile(r.Context()) + // if profile.Name == "" && r.URL.Path != "/auth/onboarding" { + // // User hasn't completed onboarding, redirect to onboarding + // if r.Header.Get("HX-Request") == "true" { + // w.Header().Set("HX-Redirect", "/auth/onboarding") + // w.WriteHeader(http.StatusSeeOther) + // return + // } + // http.Redirect(w, r, "/auth/onboarding", http.StatusSeeOther) + // return + // } next.ServeHTTP(w, r) } diff --git a/internal/middleware/auth.go.bak b/internal/middleware/auth.go.bak new file mode 100644 index 0000000..99505c2 --- /dev/null +++ b/internal/middleware/auth.go.bak @@ -0,0 +1,117 @@ +package middleware + +import ( + "net/http" + + "git.juancwu.dev/juancwu/budgit/internal/ctxkeys" + "git.juancwu.dev/juancwu/budgit/internal/service" +) + +// TODO: implement clearing jwt token in auth service + +// AuthMiddleware checks for JWT token and adds user + profile + subscription to context if valid +func AuthMiddleware(authService *service.AuthService, userService *service.UserService, profileService *service.ProfileService) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get JWT from cookie + cookie, err := r.Cookie("auth_token") + if err != nil { + // No cookie, continue without auth + next.ServeHTTP(w, r) + return + } + + // Verify token + claims, err := authService.VerifyJWT(cookie.Value) + if err != nil { + // Invalid token, clear cookie and continue + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Get user ID from claims + userID, ok := claims["user_id"].(int64) + if !ok { + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Fetch user from database + user, err := userService.ByID(userID) + if err != nil { + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Security: Remove password hash from context + user.PasswordHash = nil + + profile, err := profileService.ByUserID(userID) + if err != nil { + // Profile not found - this shouldn't happen but handle gracefully + authService.ClearJWTCookie(w) + next.ServeHTTP(w, r) + return + } + + // Add user + profile to context + ctx := ctxkeys.WithUser(r.Context(), user) + ctx = ctxkeys.WithProfile(ctx, profile) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireGuest ensures request is not authenticated +func RequireGuest(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := ctxkeys.User(r.Context()) + if user != nil { + if r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/app/dashboard") + w.WriteHeader(http.StatusSeeOther) + return + } + http.Redirect(w, r, "/app/dashboard", http.StatusSeeOther) + return + } + next.ServeHTTP(w, r) + } +} + +// RequireAuth ensures the user is authenticated and has completed onboarding +func RequireAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := ctxkeys.User(r.Context()) + if user == nil { + // For HTMX requests, use HX-Redirect header to force full page redirect + if r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/auth") + w.WriteHeader(http.StatusSeeOther) + return + } + // For regular requests, use standard redirect + http.Redirect(w, r, "/auth", http.StatusSeeOther) + return + } + + // Check if user has completed onboarding + // Uses profile.Name as indicator (empty = incomplete onboarding) + // profile := ctxkeys.Profile(r.Context()) + // if profile.Name == "" && r.URL.Path != "/auth/onboarding" { + // // User hasn't completed onboarding, redirect to onboarding + // if r.Header.Get("HX-Request") == "true" { + // w.Header().Set("HX-Redirect", "/auth/onboarding") + // w.WriteHeader(http.StatusSeeOther) + // return + // } + // http.Redirect(w, r, "/auth/onboarding", http.StatusSeeOther) + // return + // } + + next.ServeHTTP(w, r) + } +} diff --git a/internal/model/file.go b/internal/model/file.go index a79c301..55641ad 100644 --- a/internal/model/file.go +++ b/internal/model/file.go @@ -9,8 +9,8 @@ const ( ) type File struct { - ID uint64 `db:"id"` - UserID uint64 `db:"user_id"` // Who owns/created this file + ID string `db:"id"` + UserID string `db:"user_id"` // Who owns/created this file OwnerType string `db:"owner_type"` // "user", "profile", etc. - the entity that owns the file OwnerID string `db:"owner_id"` // Polymorphic FK Type string `db:"type"` diff --git a/internal/model/profile.go b/internal/model/profile.go index 36e6bd5..02ea260 100644 --- a/internal/model/profile.go +++ b/internal/model/profile.go @@ -3,8 +3,8 @@ package model import "time" type Profile struct { - ID uint64 `db:"id"` - UserID uint64 `db:"user_id"` + ID string `db:"id"` + UserID string `db:"user_id"` Name string `db:"name"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` diff --git a/internal/model/token.go b/internal/model/token.go index 5a7ec18..d76d10f 100644 --- a/internal/model/token.go +++ b/internal/model/token.go @@ -5,8 +5,8 @@ import ( ) type Token struct { - ID uint64 `db:"id"` - UserID uint64 `db:"user_id"` + ID string `db:"id"` + UserID string `db:"user_id"` Type string `db:"type"` // "email_verify" or "password_reset" Token string `db:"token"` ExpiresAt time.Time `db:"expires_at"` diff --git a/internal/model/user.go b/internal/model/user.go index a077c7f..a9f39d5 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -3,7 +3,7 @@ package model import "time" type User struct { - ID uint64 `db:"id"` + ID string `db:"id"` Email string `db:"email"` // Allow null for passwordless users PasswordHash *string `db:"password_hash"` diff --git a/internal/repository/profile.go b/internal/repository/profile.go new file mode 100644 index 0000000..7e66a3a --- /dev/null +++ b/internal/repository/profile.go @@ -0,0 +1,60 @@ +package repository + +import ( + "database/sql" + "errors" + "time" + + "git.juancwu.dev/juancwu/budgit/internal/model" + "github.com/jmoiron/sqlx" +) + +var ( + ErrProfileNotFound = errors.New("profile not found") +) + +type ProfileRepository interface { + Create(profile *model.Profile) (string, error) + ByUserID(userID string) (*model.Profile, error) +} + +type profileRepository struct { + db *sqlx.DB +} + +func NewProfileRepository(db *sqlx.DB) *profileRepository { + return &profileRepository{db: db} +} + +func (r *profileRepository) Create(profile *model.Profile) (string, error) { + if profile.CreatedAt.IsZero() { + profile.CreatedAt = time.Now() + } + if profile.UpdatedAt.IsZero() { + profile.UpdatedAt = time.Now() + } + + _, err := r.db.Exec(` + INSERT INTO profiles (id, user_id, name, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5) + `, profile.ID, profile.UserID, profile.Name, profile.CreatedAt, profile.UpdatedAt) + if err != nil { + return "", err + } + + return profile.ID, nil +} + +func (r *profileRepository) ByUserID(userID string) (*model.Profile, error) { + var profile model.Profile + err := r.db.Get(&profile, `SELECT * FROM profiles WHERE user_id = $1`, userID) + + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrProfileNotFound + } + if err != nil { + return nil, err + } + + return &profile, nil +} diff --git a/internal/repository/token.go b/internal/repository/token.go new file mode 100644 index 0000000..f115da9 --- /dev/null +++ b/internal/repository/token.go @@ -0,0 +1,77 @@ +package repository + +import ( + "database/sql" + "errors" + "fmt" + "time" + + "git.juancwu.dev/juancwu/budgit/internal/model" + "github.com/jmoiron/sqlx" +) + +var ( + ErrTokenNotFound = errors.New("token not found") +) + +type TokenRepository interface { + Create(token *model.Token) (string, error) + DeleteByUserAndType(userID string, tokenType string) error + ConsumeToken(token string) (*model.Token, error) +} + +type tokenRepository struct { + db *sqlx.DB +} + +func NewTokenRepository(db *sqlx.DB) *tokenRepository { + return &tokenRepository{db: db} +} + +func (r *tokenRepository) Create(token *model.Token) (string, error) { + if token.CreatedAt.IsZero() { + token.CreatedAt = time.Now() + } + + query := ` + INSERT INTO tokens (id, user_id, type, token, expires_at, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + ` + + _, err := r.db.Exec(query, token.ID, token.UserID, token.Type, token.Token, token.ExpiresAt, token.CreatedAt) + if err != nil { + return "", fmt.Errorf("failed to create token: %w", err) + } + + return token.ID, nil +} + +func (r *tokenRepository) DeleteByUserAndType(userID string, tokenType string) error { + query := `DELETE FROM tokens WHERE user_id = $1 AND type = $2 AND used_at IS NULL` + _, err := r.db.Exec(query, userID, tokenType) + return err +} + +func (r *tokenRepository) ConsumeToken(tokenString string) (*model.Token, error) { + var token model.Token + now := time.Now() + + query := ` + UPDATE tokens + SET used_at = $1 + WHERE token = $2 + AND used_at IS NULL + AND expires_at > $3 + RETURNING * + ` + + err := r.db.Get(&token, query, now, tokenString, now) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTokenNotFound + } + if err != nil { + return nil, err + } + + return &token, nil +} diff --git a/internal/repository/user.go b/internal/repository/user.go index 61b95b6..76a66ba 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -15,7 +15,7 @@ var ( ) type UserRepository interface { - Create(user *model.User) error + Create(user *model.User) (string, error) ByID(id string) (*model.User, error) ByEmail(email string) (*model.User, error) Update(user *model.User) error @@ -30,19 +30,19 @@ func NewUserRepository(db *sqlx.DB) UserRepository { return &userRepository{db: db} } -func (r *userRepository) Create(user *model.User) error { +func (r *userRepository) Create(user *model.User) (string, error) { query := `INSERT INTO users (id, email, password_hash, email_verified_at, created_at) VALUES ($1, $2, $3, $4, $5);` _, err := r.db.Exec(query, user.ID, user.Email, user.PasswordHash, user.EmailVerifiedAt, user.CreatedAt) if err != nil { errStr := err.Error() if strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "duplicate key value") { - return ErrDuplicateEmail + return "", ErrDuplicateEmail } - return err + return "", err } - return nil + return user.ID, nil } func (r *userRepository) ByID(id string) (*model.User, error) { @@ -58,15 +58,15 @@ func (r *userRepository) ByID(id string) (*model.User, error) { } func (r *userRepository) ByEmail(email string) (*model.User, error) { - user := &model.User{} + var user model.User query := `SELECT * FROM users WHERE email = $1;` - err := r.db.Get(user, query, email) + err := r.db.Get(&user, query, email) if err == sql.ErrNoRows { return nil, ErrUserNotFound } - return user, err + return &user, err } func (r *userRepository) Update(user *model.User) error { diff --git a/internal/routes/routes.go b/internal/routes/routes.go index e38b0d8..aa1f525 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -11,7 +11,7 @@ import ( ) func SetupRoutes(a *app.App) http.Handler { - auth := handler.NewAuthHandler() + auth := handler.NewAuthHandler(a.AuthService) home := handler.NewHomeHandler() dashboard := handler.NewDashboardHandler() @@ -26,9 +26,17 @@ func SetupRoutes(a *app.App) http.Handler { mux.Handle("GET /assets/", http.StripPrefix("/assets/", http.FileServer(http.FS(sub)))) // Auth pages + // authRateLimiter := middleware.RateLimitAuth() + mux.HandleFunc("GET /auth", middleware.RequireGuest(auth.AuthPage)) mux.HandleFunc("GET /auth/password", middleware.RequireGuest(auth.PasswordPage)) + // Token Verifications + mux.HandleFunc("GET /auth/magic-link/{token}", auth.VerifyMagicLink) + + // Auth Actions + mux.HandleFunc("POST /auth/magic-link", middleware.RequireGuest(auth.SendMagicLink)) + // ==================================================================================== // PRIVATE ROUTES // ==================================================================================== @@ -44,7 +52,7 @@ func SetupRoutes(a *app.App) http.Handler { middleware.Config(a.Cfg), middleware.RequestLogging, middleware.CSRFProtection, - middleware.AuthMiddleware(a.AuthService, a.UserService), + middleware.AuthMiddleware(a.AuthService, a.UserService, a.ProfileService), middleware.WithURLPath, ) diff --git a/internal/service/auth.go b/internal/service/auth.go index 3a73ed7..94bcf04 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -1,14 +1,22 @@ package service import ( + "crypto/rand" + "encoding/hex" "errors" + "fmt" + "log/slog" + "net/http" "strings" + "time" "git.juancwu.dev/juancwu/budgit/internal/exception" "git.juancwu.dev/juancwu/budgit/internal/model" "git.juancwu.dev/juancwu/budgit/internal/repository" + "git.juancwu.dev/juancwu/budgit/internal/validation" "github.com/alexedwards/argon2id" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" ) var ( @@ -24,12 +32,35 @@ var ( ) type AuthService struct { - userRepository repository.UserRepository + emailService *EmailService + userRepository repository.UserRepository + profileRepository repository.ProfileRepository + tokenRepository repository.TokenRepository + jwtSecret string + jwtExpiry time.Duration + tokenMagicLinkExpiry time.Duration + isProduction bool } -func NewAuthService(userRepository repository.UserRepository) *AuthService { +func NewAuthService( + emailService *EmailService, + userRepository repository.UserRepository, + profileRepository repository.ProfileRepository, + tokenRepository repository.TokenRepository, + jwtSecret string, + jwtExpiry time.Duration, + tokenMagicLinkExpiry time.Duration, + isProduction bool, +) *AuthService { return &AuthService{ - userRepository: userRepository, + emailService: emailService, + userRepository: userRepository, + profileRepository: profileRepository, + tokenRepository: tokenRepository, + jwtSecret: jwtSecret, + jwtExpiry: jwtExpiry, + tokenMagicLinkExpiry: tokenMagicLinkExpiry, + isProduction: isProduction, } } @@ -75,6 +106,188 @@ func (s *AuthService) ComparePassword(password, hash string) error { return nil } -func (s *AuthService) VerifyJWT(value string) (jwt.MapClaims, error) { - return nil, nil +func (s *AuthService) GenerateJWT(user *model.User) (string, error) { + claims := jwt.MapClaims{ + "user_id": user.ID, + "email": user.Email, + "exp": time.Now().Add(s.jwtExpiry).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + tokenString, err := token.SignedString([]byte(s.jwtSecret)) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func (s *AuthService) VerifyJWT(tokenString string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(s.jwtSecret), nil + }) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if ok && token.Valid { + return claims, nil + } + + return nil, fmt.Errorf("invalid token") +} + +func (s *AuthService) SetJWTCookie(w http.ResponseWriter, token string, expiry time.Time) { + http.SetCookie(w, &http.Cookie{ + Name: "auth_token", + Value: token, + Expires: expiry, + Path: "/", + HttpOnly: true, + Secure: s.isProduction, + SameSite: http.SameSiteLaxMode, + }) +} + +func (s *AuthService) ClearJWTCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: "auth_token", + Value: "", + Expires: time.Unix(0, 0), + Path: "/", + HttpOnly: true, + Secure: s.isProduction, + SameSite: http.SameSiteLaxMode, + }) +} + +func (s *AuthService) GenerateToken() (string, error) { + bytes := make([]byte, 32) + _, err := rand.Read(bytes) + if err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +func (s *AuthService) SendMagicLink(email string) error { + email = strings.TrimSpace(strings.ToLower(email)) + + err := validation.ValidateEmail(email) + if err != nil { + return ErrInvalidEmail + } + + user, err := s.userRepository.ByEmail(email) + if err != nil { + // User doesn't exists - create a new passwordless account + if errors.Is(err, repository.ErrUserNotFound) { + now := time.Now() + user = &model.User{ + ID: uuid.NewString(), + Email: email, + CreatedAt: now, + } + _, err := s.userRepository.Create(user) + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + + slog.Info("new user created with id", "id", user.ID) + + profile := &model.Profile{ + ID: uuid.NewString(), + UserID: user.ID, + Name: "", + CreatedAt: now, + UpdatedAt: now, + } + + _, err = s.profileRepository.Create(profile) + if err != nil { + return fmt.Errorf("failed to create profile: %w", err) + } + + slog.Info("new passwordless user created", "email", email, "user_id", user.ID) + } else { + // user look up unexpected error + return fmt.Errorf("failed to look up user: %w", err) + } + } + + err = s.tokenRepository.DeleteByUserAndType(user.ID, model.TokenTypeMagicLink) + if err != nil { + slog.Warn("failed to delete old magic link tokens", "error", err, "user_id", user.ID) + } + + magicToken, err := s.GenerateToken() + if err != nil { + return fmt.Errorf("failed to generate token: %w", err) + } + + token := &model.Token{ + ID: uuid.NewString(), + UserID: user.ID, + Type: model.TokenTypeMagicLink, + Token: magicToken, + ExpiresAt: time.Now().Add(s.tokenMagicLinkExpiry), + } + + _, err = s.tokenRepository.Create(token) + if err != nil { + return fmt.Errorf("failed to create token: %w", err) + } + + profile, err := s.profileRepository.ByUserID(user.ID) + name := "" + if err == nil && profile != nil { + name = profile.Name + } + + err = s.emailService.SendMagicLinkEmail(user.Email, magicToken, name) + if err != nil { + slog.Error("failed to send magic link email", "error", err, "email", user.Email) + return fmt.Errorf("failed to send email: %w", err) + } + + slog.Info("magic link sent", "email", user.Email) + return nil +} + +func (s *AuthService) VerifyMagicLink(tokenString string) (*model.User, error) { + token, err := s.tokenRepository.ConsumeToken(tokenString) + if err != nil { + return nil, fmt.Errorf("invalid or expired magic link") + } + + if token.Type != model.TokenTypeMagicLink { + return nil, fmt.Errorf("invalid token type") + } + + user, err := s.userRepository.ByID(token.UserID) + if errors.Is(err, repository.ErrUserNotFound) { + return nil, err + } + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + + if user.EmailVerifiedAt == nil { + now := time.Now() + user.EmailVerifiedAt = &now + err = s.userRepository.Update(user) + if err != nil { + slog.Warn("failed to set email verification time", "error", err, "user_id", user.ID) + } + } + + slog.Info("user authenticated via magic link", "user_id", user.ID, "email", user.Email) + + return user, nil } diff --git a/internal/service/email.go b/internal/service/email.go index 73add75..d4c0312 100644 --- a/internal/service/email.go +++ b/internal/service/email.go @@ -13,15 +13,14 @@ import ( ) type EmailParams struct { - From string - EnvelopeFrom string - To []string - Bcc []string - Cc []string - ReplyTo string - Subject string - Text string - Html string + From string + To []string + Bcc []string + Cc []string + ReplyTo string + Subject string + Text string + Html string } type EmailClient struct { @@ -47,12 +46,20 @@ func NewEmailClient(smtpHost string, smtpPort int, imapHost string, imapPort int func (nc *EmailClient) SendWithContext(ctx context.Context, params *EmailParams) (string, error) { m := mail.NewMsg() m.From(params.From) - m.EnvelopeFrom(params.EnvelopeFrom) m.To(params.To...) m.Subject(params.Subject) - m.SetBodyString(mail.TypeTextPlain, params.Text) - m.SetBodyString(mail.TypeTextHTML, params.Html) - m.ReplyTo(params.ReplyTo) + + if params.Html != "" { + m.SetBodyString(mail.TypeTextHTML, params.Html) + m.AddAlternativeString(mail.TypeTextPlain, params.Text) + } else { + m.SetBodyString(mail.TypeTextPlain, params.Text) + } + + if params.ReplyTo != "" { + m.ReplyTo(params.ReplyTo) + } + m.SetDate() m.SetMessageID() @@ -126,26 +133,20 @@ func (nc *EmailClient) connectToIMAP() (*client.Client, error) { } type EmailService struct { - client *EmailClient - fromEmail string - fromEnvelope string - supportEmail string - supportEnvelope string - isDev bool - appURL string - appName string + client *EmailClient + fromEmail string + isProd bool + appURL string + appName string } -func NewEmailService(client *EmailClient, fromEmail, fromEnvelope, supportEmail, supportEnvelope, appURL, appName string, isDev bool) *EmailService { +func NewEmailService(client *EmailClient, fromEmail, appURL, appName string, isProd bool) *EmailService { return &EmailService{ - client: client, - fromEmail: fromEmail, - fromEnvelope: fromEnvelope, - supportEmail: supportEmail, - supportEnvelope: supportEnvelope, - isDev: isDev, - appURL: appURL, - appName: appName, + client: client, + fromEmail: fromEmail, + isProd: isProd, + appURL: appURL, + appName: appName, } } @@ -153,10 +154,10 @@ func (s *EmailService) SendMagicLinkEmail(email, token, name string) error { magicURL := fmt.Sprintf("%s/auth/magic-link/%s", s.appURL, token) subject, body := magicLinkEmailTemplate(magicURL, s.appName) - if s.isDev { - slog.Info("email sent (dev mode)", "type", "magic_link", "to", email, "subject", subject, "url", magicURL) - return nil - } + // if !s.isProd { + // slog.Info("email sent (dev mode)", "type", "magic_link", "to", email, "subject", subject, "url", magicURL) + // return nil + // } params := &EmailParams{ From: s.fromEmail, diff --git a/internal/service/profile.go b/internal/service/profile.go new file mode 100644 index 0000000..95cda4d --- /dev/null +++ b/internal/service/profile.go @@ -0,0 +1,20 @@ +package service + +import ( + "git.juancwu.dev/juancwu/budgit/internal/model" + "git.juancwu.dev/juancwu/budgit/internal/repository" +) + +type ProfileService struct { + profileRepository repository.ProfileRepository +} + +func NewProfileService(profileRepository repository.ProfileRepository) *ProfileService { + return &ProfileService{ + profileRepository: profileRepository, + } +} + +func (s *ProfileService) ByUserID(userID string) (*model.Profile, error) { + return s.profileRepository.ByUserID(userID) +} diff --git a/internal/ui/pages/auth_magic_link_sent.templ b/internal/ui/pages/auth_magic_link_sent.templ new file mode 100644 index 0000000..d8bda9b --- /dev/null +++ b/internal/ui/pages/auth_magic_link_sent.templ @@ -0,0 +1,67 @@ +package pages + +import ( + "git.juancwu.dev/juancwu/budgit/internal/ui/layouts" + "git.juancwu.dev/juancwu/budgit/internal/ctxkeys" + "git.juancwu.dev/juancwu/budgit/internal/ui/components/icon" + "git.juancwu.dev/juancwu/budgit/internal/ui/components/csrf" + "git.juancwu.dev/juancwu/budgit/internal/ui/components/button" +) + +templ MagicLinkSent(email string) { + @layouts.Auth(layouts.SEOProps{ + Title: "Check Your Email", + Description: "Magic link sent to your email", + Path: ctxkeys.URLPath(ctx), + }) { +
We've sent a magic link to
+{ email }
+Click the link in your email to sign in instantly.
+The link will expire in 10 minutes and can only be used once.
++ Didn't receive it? Check your spam folder or + + contact support + +
+ +