diff --git a/store/queries.go b/store/queries.go index 492e4cd..8423438 100644 --- a/store/queries.go +++ b/store/queries.go @@ -1,8 +1,9 @@ package store type Queries struct { - CreateUser string - GetUserByID string + CreateUser string + GetUserByID string + GetUserByEmail string } var CanonicalQueries Queries = Queries{ @@ -23,11 +24,21 @@ var CanonicalQueries Queries = Queries{ created_at, updated_at FROM pase_users WHERE id = ?;`, + GetUserByEmail: ` + SELECT + id, email, email_verified_at, + username, username_normalized, display_name, profile_image_url, + status, status_reason, status_changed_at, status_expires_at, + failed_login_count, last_failed_login_at, + created_at, updated_at + FROM pase_users + WHERE email = ?;`, } func (q Queries) Rebind(d Dialect) Queries { return Queries{ - CreateUser: d.Rebind(q.CreateUser), - GetUserByID: d.Rebind(q.GetUserByID), + CreateUser: d.Rebind(q.CreateUser), + GetUserByID: d.Rebind(q.GetUserByID), + GetUserByEmail: d.Rebind(q.GetUserByEmail), } } diff --git a/store/sqlite/store.go b/store/sqlite/store.go index 0a586da..8349c2a 100644 --- a/store/sqlite/store.go +++ b/store/sqlite/store.go @@ -92,3 +92,31 @@ func (s *Store) GetUserByID(ctx context.Context, id string) (*store.User, error) } return &u, nil } + +func (s *Store) GetUserByEmail(ctx context.Context, email string) (*store.User, error) { + row := s.db.QueryRowContext(ctx, s.q.GetUserByEmail, email) + u, err := s.scanUser(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("pase/sqlite: get user by email: %w", store.ErrUserNotFound) + } + return nil, fmt.Errorf("pase/sqlite: get user by email: %w", err) + } + return u, nil +} + +func (s *Store) scanUser(row *sql.Row) (*store.User, error) { + var u store.User + err := row.Scan( + &u.ID, &u.Email, &u.EmailVerifiedAt, + &u.Username, &u.UsernameNormalized, &u.DisplayName, &u.ProfileImageURL, + &u.Status, &u.StatusReason, + &u.StatusChangedAt, &u.StatusExpiresAt, + &u.FailedLoginCount, &u.LastFailedLoginAt, + &u.CreatedAt, &u.UpdatedAt, + ) + if err != nil { + return nil, err + } + return &u, nil +} diff --git a/store/storetest/storetest.go b/store/storetest/storetest.go index 5ac92df..2f02a69 100644 --- a/store/storetest/storetest.go +++ b/store/storetest/storetest.go @@ -28,6 +28,7 @@ import ( type SuiteStore interface { CreateUser(ctx context.Context, u *store.User) error GetUserByID(ctx context.Context, id string) (*store.User, error) + GetUserByEmail(ctx context.Context, email string) (*store.User, error) } // Factory returns a fresh, isolated Store. Each call to the factory must @@ -53,6 +54,7 @@ func RunSuite(t *testing.T, newStore Factory) { {"GetUserByID_notFound", testGetUserByIDNotFound}, {"CreateUser_duplicateEmail", testCreateUserDuplicateEmail}, {"CreateUser_duplicateUsernameNormalized", testCreateUserDuplicateUsernameNormalized}, + {"CreateUser_GetUserByEmail_roundTrip", testCreateUserGetUserByEmailRoundTrip}, } for _, tc := range cases { @@ -138,6 +140,22 @@ func testCreateUserDuplicateUsernameNormalized(t *testing.T, s SuiteStore) { } } +func testCreateUserGetUserByEmailRoundTrip(t *testing.T, s SuiteStore) { + ctx := context.Background() + + want := FixedUser() + if err := s.CreateUser(ctx, want); err != nil { + t.Fatalf("CreateUser: %v", err) + } + + got, err := s.GetUserByEmail(ctx, want.Email) + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + + AssertUserEqual(t, got, want) +} + // --------------------------------------------------------------------------- // Fixtures and helpers. Exported so dialect-specific tests can reuse them // for one-off cases that don't fit into the shared suite.