Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
fix: auth server bugs and auth client bugs (#125)
Browse files Browse the repository at this point in the history
Co-authored-by: maherhamoui6 <hammimo022@gmail.com>
Co-authored-by: rot1024 <aayhrot@gmail.com>
  • Loading branch information
3 people authored Mar 16, 2022
1 parent 82cf28c commit ce23099
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 55 deletions.
1 change: 1 addition & 0 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo {
api.GET("/published/:name", PublishedMetadata())
api.GET("/published_data/:name", PublishedData())

// authenticated endpoints
privateApi := api.Group("", AuthRequiredMiddleware())
graphqlAPI(e, privateApi, cfg)
privateAPI(e, privateApi, cfg.Repos)
Expand Down
34 changes: 17 additions & 17 deletions internal/app/auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
Expand All @@ -15,12 +14,14 @@ import (
"github.com/labstack/echo/v4"
"github.com/reearth/reearth-backend/internal/usecase/interactor"
"github.com/reearth/reearth-backend/internal/usecase/interfaces"
"github.com/reearth/reearth-backend/pkg/log"
)

var (
const (
loginEndpoint = "api/login"
logoutEndpoint = "api/logout"
jwksEndpoint = ".well-known/jwks.json"
authProvider = "reearth"
)

func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *ServerConfig) {
Expand Down Expand Up @@ -65,7 +66,7 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server
userUsecase.GetUserBySubject,
)
if err != nil {
e.Logger.Fatal(err)
log.Fatalf("auth: init failed: %s\n", err)
}

handler, err := op.NewOpenIDProvider(
Expand All @@ -78,13 +79,13 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server
op.WithCustomKeysEndpoint(op.NewEndpoint(jwksEndpoint)),
)
if err != nil {
e.Logger.Fatal(fmt.Errorf("auth: init failed: %w", err))
log.Fatalf("auth: init failed: %s\n", err)
}

router := handler.HttpHandler().(*mux.Router)

if err := router.Walk(muxToEchoMapper(r)); err != nil {
e.Logger.Fatal(fmt.Errorf("auth: walk failed: %w", err))
log.Fatalf("auth: walk failed: %s\n", err)
}

// Actual login endpoint
Expand Down Expand Up @@ -178,23 +179,22 @@ type loginForm struct {

func login(ctx context.Context, cfg *ServerConfig, storage op.Storage, userUsecase interfaces.User) func(ctx echo.Context) error {
return func(ec echo.Context) error {

request := new(loginForm)
err := ec.Bind(request)
if err != nil {
ec.Logger().Error("filed to parse login request")
return err
log.Errorln("auth: filed to parse login request")
return ec.Redirect(http.StatusFound, redirectURL(ec.Request().Referer(), !cfg.Debug, "", "Bad request!"))
}

authRequest, err := storage.AuthRequestByID(ctx, request.AuthRequestID)
if err != nil {
ec.Logger().Error("filed to parse login request")
return err
log.Errorf("auth: filed to parse login request: %s\n", err)
return ec.Redirect(http.StatusFound, redirectURL(ec.Request().Referer(), !cfg.Debug, "", "Bad request!"))
}

if len(request.Email) == 0 || len(request.Password) == 0 {
ec.Logger().Error("credentials are not provided")
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "invalid login"))
log.Errorln("auth: one of credentials are not provided")
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Bad request!"))
}

// check user credentials from db
Expand All @@ -203,15 +203,15 @@ func login(ctx context.Context, cfg *ServerConfig, storage op.Storage, userUseca
Password: request.Password,
})
if err != nil {
ec.Logger().Error("wrong credentials!")
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "invalid login"))
log.Errorf("auth: wrong credentials: %s\n", err)
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Login failed; Invalid user ID or password."))
}

// Complete the auth request && set the subject
err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider("reearth").Sub)
err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider(authProvider).Sub)
if err != nil {
ec.Logger().Error("failed to complete the auth request !")
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "invalid login"))
log.Errorf("auth: failed to complete the auth request: %s\n", err)
return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "Bad request!"))
}

return ec.Redirect(http.StatusFound, "/authorize/callback?id="+request.AuthRequestID)
Expand Down
2 changes: 1 addition & 1 deletion internal/app/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func jwtEchoMiddleware(cfg *ServerConfig) echo.MiddlewareFunc {
log.Fatalf("failed to set up the validator: %v", err)
}

middleware := jwtmiddleware.New(jwtValidator.ValidateToken)
middleware := jwtmiddleware.New(jwtValidator.ValidateToken, jwtmiddleware.WithCredentialsOptional(true))

return echo.WrapMiddleware(middleware.CheckJWT)
}
Expand Down
16 changes: 10 additions & 6 deletions internal/app/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,22 @@ func PasswordReset() echo.HandlerFunc {
uc := adapter.Usecases(c.Request().Context())
controller := http1.NewUserController(uc.User)

if len(inp.Email) > 0 {
isStartingNewRequest := len(inp.Email) > 0 && len(inp.Token) == 0 && len(inp.Password) == 0
isSettingNewPassword := len(inp.Email) > 0 && len(inp.Token) > 0 && len(inp.Password) > 0

if isStartingNewRequest {
if err := controller.StartPasswordReset(c.Request().Context(), inp); err != nil {
return err
c.Logger().Error("an attempt to start reset password failed. internal error: %w", err)
}
return c.JSON(http.StatusOK, true)
return c.JSON(http.StatusOK, echo.Map{"message": "If that email address is in our database, we will send you an email to reset your password."})
}

if len(inp.Token) > 0 && len(inp.Password) > 0 {
if isSettingNewPassword {
if err := controller.PasswordReset(c.Request().Context(), inp); err != nil {
return err
c.Logger().Error("an attempt to Set password failed. internal error: %w", err)
return c.JSON(http.StatusBadRequest, echo.Map{"message": "Bad set password request"})
}
return c.JSON(http.StatusOK, true)
return c.JSON(http.StatusOK, echo.Map{"message": "Password is updated successfully"})
}

return &echo.HTTPError{Code: http.StatusBadRequest, Message: "Bad reset password request"}
Expand Down
17 changes: 17 additions & 0 deletions internal/infrastructure/memory/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ func (r *User) FindByEmail(ctx context.Context, email string) (*user.User, error
return nil, rerror.ErrNotFound
}

func (r *User) FindByName(ctx context.Context, name string) (*user.User, error) {
r.lock.Lock()
defer r.lock.Unlock()

if name == "" {
return nil, rerror.ErrInvalidParams
}

for _, u := range r.data {
if u.Name() == name {
return &u, nil
}
}

return nil, rerror.ErrNotFound
}

func (r *User) FindByNameOrEmail(ctx context.Context, nameOrEmail string) (*user.User, error) {
r.lock.Lock()
defer r.lock.Unlock()
Expand Down
37 changes: 24 additions & 13 deletions internal/infrastructure/mongo/mongodoc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,24 +304,35 @@ func (c *Client) Paginate(ctx context.Context, col string, filter interface{}, p
}

func (c *Client) CreateIndex(ctx context.Context, col string, keys []string) []string {
return c.CreateUniqueIndex(ctx, col, keys, []string{})
}

func (c *Client) CreateUniqueIndex(ctx context.Context, col string, keys, uniqueKeys []string) []string {
coll := c.Collection(col)
indexedKeys := indexes(ctx, coll)

newIndexes := []mongo.IndexModel{}
// store unique keys as map to check them in an efficient way
ukm := map[string]struct{}{}
for _, k := range append([]string{"id"}, uniqueKeys...) {
ukm[k] = struct{}{}
}

var newIndexes []mongo.IndexModel
for _, k := range append([]string{"id"}, keys...) {
if _, ok := indexedKeys[k]; !ok {
indexBg := true
unique := k == "id"
newIndexes = append(newIndexes, mongo.IndexModel{
Keys: map[string]int{
k: 1,
},
Options: &options.IndexOptions{
Background: &indexBg,
Unique: &unique,
},
})
if _, ok := indexedKeys[k]; ok {
continue
}
indexBg := true
_, isUnique := ukm[k]
newIndexes = append(newIndexes, mongo.IndexModel{
Keys: map[string]int{
k: 1,
},
Options: &options.IndexOptions{
Background: &indexBg,
Unique: &isUnique,
},
})
}

if len(newIndexes) > 0 {
Expand Down
4 changes: 4 additions & 0 deletions internal/infrastructure/mongo/mongodoc/clientcol.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ func (c *ClientCollection) RemoveAll(ctx context.Context, f interface{}) error {
func (c *ClientCollection) CreateIndex(ctx context.Context, keys []string) []string {
return c.Client.CreateIndex(ctx, c.CollectionName, keys)
}

func (c *ClientCollection) CreateUniqueIndex(ctx context.Context, keys, uniqueKeys []string) []string {
return c.Client.CreateUniqueIndex(ctx, c.CollectionName, keys, uniqueKeys)
}
7 changes: 6 additions & 1 deletion internal/infrastructure/mongo/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewUser(client *mongodoc.Client) repo.User {
}

func (r *userRepo) init() {
i := r.client.CreateIndex(context.Background(), []string{"email", "auth0sublist"})
i := r.client.CreateUniqueIndex(context.Background(), []string{"email", "name", "auth0sublist"}, []string{"name"})
if len(i) > 0 {
log.Infof("mongo: %s: index created: %s", "user", i)
}
Expand Down Expand Up @@ -65,6 +65,11 @@ func (r *userRepo) FindByEmail(ctx context.Context, email string) (*user.User, e
return r.findOne(ctx, filter)
}

func (r *userRepo) FindByName(ctx context.Context, name string) (*user.User, error) {
filter := bson.D{{Key: "name", Value: name}}
return r.findOne(ctx, filter)
}

func (r *userRepo) FindByNameOrEmail(ctx context.Context, nameOrEmail string) (*user.User, error) {
filter := bson.D{{Key: "$or", Value: []bson.D{
{{Key: "email", Value: nameOrEmail}},
Expand Down
64 changes: 47 additions & 17 deletions internal/usecase/interactor/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (u *user.
return
}
} else if isAuth {
if *inp.Name == "" {
if _, err := mail.ParseAddress(*inp.Name); err == nil || *inp.Name == "" {
return nil, nil, interfaces.ErrSignupInvalidName
}
if _, err := mail.ParseAddress(*inp.Email); err != nil {
Expand Down Expand Up @@ -228,28 +228,38 @@ func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (u *user.

func (i *User) reearthSignup(ctx context.Context, inp interfaces.SignupParam) (string, string, *user.User, *user.Team, error) {
// Check if user email already exists
existed, err := i.userRepo.FindByEmail(ctx, *inp.Email)
existedByEmail, err := i.userRepo.FindByEmail(ctx, *inp.Email)
if err != nil && !errors.Is(err, rerror.ErrNotFound) {
return "", "", nil, nil, err
}

if existed != nil {
if existed.Verification().IsVerified() {
return "", "", nil, nil, errors.New("existed user email")
} else {
// if user exists but not verified -> create a new verification
if err := i.CreateVerification(ctx, *inp.Email); err != nil {
return "", "", nil, nil, err
} else {
team, err := i.teamRepo.FindByID(ctx, existed.Team())
if err != nil && !errors.Is(err, rerror.ErrNotFound) {
return "", "", nil, nil, err
}
return "", "", existed, team, nil
}
if existedByEmail != nil {
if existedByEmail.Verification() != nil && existedByEmail.Verification().IsVerified() {
return "", "", nil, nil, errors.New("existed email")
}

// if user exists but not verified -> create a new verification
if err := i.CreateVerification(ctx, *inp.Email); err != nil {
return "", "", nil, nil, err
}

team, err := i.teamRepo.FindByID(ctx, existedByEmail.Team())
if err != nil && !errors.Is(err, rerror.ErrNotFound) {
return "", "", nil, nil, err
}
return "", "", existedByEmail, team, nil
}

existedByName, err := i.userRepo.FindByName(ctx, *inp.Name)
if err != nil && !errors.Is(err, rerror.ErrNotFound) {
return "", "", nil, nil, err
}

if existedByName != nil {
return "", "", nil, nil, errors.New("taken username")
}

// !existedByName && !existedByEmail
return *inp.Name, *inp.Email, nil, nil, nil
}

Expand Down Expand Up @@ -305,6 +315,9 @@ func (i *User) GetUserByCredentials(ctx context.Context, inp interfaces.GetUserB
if !matched {
return nil, interfaces.ErrSignupInvalidPassword
}
if u.Verification() == nil || !u.Verification().IsVerified() {
return nil, interfaces.ErrNotVerifiedUser
}
return u, nil
}

Expand Down Expand Up @@ -430,7 +443,15 @@ func (i *User) UpdateMe(ctx context.Context, p interfaces.UpdateMeParam, operato
return nil, err
}

if p.Name != nil {
if p.Name != nil && *p.Name != u.Name() {
// username should not be a valid mail
if _, err := mail.ParseAddress(*p.Name); err == nil {
return nil, interfaces.ErrSignupInvalidName
}
// make sure the username is not exists
if userByName, _ := i.userRepo.FindByName(ctx, *p.Name); userByName != nil {
return nil, interfaces.ErrSignupInvalidName
}
oldName := u.Name()
u.UpdateName(*p.Name)

Expand All @@ -456,9 +477,18 @@ func (i *User) UpdateMe(ctx context.Context, p interfaces.UpdateMeParam, operato
u.UpdateTheme(*p.Theme)
}

if p.Password != nil && u.HasAuthProvider("reearth") {
if err := u.SetPassword(*p.Password); err != nil {
return nil, err
}
}

// Update Auth0 users
if p.Name != nil || p.Email != nil || p.Password != nil {
for _, a := range u.Auths() {
if a.Provider != "auth0" {
continue
}
if _, err := i.authenticator.UpdateUser(gateway.AuthenticatorUpdateUserParam{
ID: a.Sub,
Name: p.Name,
Expand Down
1 change: 1 addition & 0 deletions internal/usecase/interfaces/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var (
ErrSignupInvalidSecret = errors.New("invalid secret")
ErrSignupInvalidName = errors.New("invalid name")
ErrInvalidUserEmail = errors.New("invalid email")
ErrNotVerifiedUser = errors.New("not verified user")
ErrSignupInvalidPassword = errors.New("invalid password")
)

Expand Down
1 change: 1 addition & 0 deletions internal/usecase/repo/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type User interface {
FindByID(context.Context, id.UserID) (*user.User, error)
FindByAuth0Sub(context.Context, string) (*user.User, error)
FindByEmail(context.Context, string) (*user.User, error)
FindByName(context.Context, string) (*user.User, error)
FindByNameOrEmail(context.Context, string) (*user.User, error)
FindByVerification(context.Context, string) (*user.User, error)
FindByPasswordResetRequest(context.Context, string) (*user.User, error)
Expand Down
12 changes: 12 additions & 0 deletions pkg/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ func (u *User) ContainAuth(a Auth) bool {
return false
}

func (u *User) HasAuthProvider(p string) bool {
if u == nil {
return false
}
for _, b := range u.auths {
if b.Provider == p {
return true
}
}
return false
}

func (u *User) AddAuth(a Auth) bool {
if u == nil {
return false
Expand Down
Loading

0 comments on commit ce23099

Please sign in to comment.