diff --git a/go.mod b/go.mod index 3d36899e..123667ce 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/goccy/go-yaml v1.9.2 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f + github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/iancoleman/strcase v0.1.3 github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d @@ -39,6 +40,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo v0.0.0-20200707171851-ae0d272a2deb go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver v0.7.0 go.opentelemetry.io/otel v0.7.0 + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/text v0.3.7 golang.org/x/tools v0.1.5 google.golang.org/api v0.51.0 @@ -66,7 +68,6 @@ require ( github.com/golang/snappy v0.0.3 // indirect github.com/google/go-cmp v0.5.6 // indirect github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/gax-go/v2 v2.0.5 // indirect github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/schema v1.2.0 // indirect @@ -95,13 +96,12 @@ require ( github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect go.opencensus.io v0.23.0 // indirect go.uber.org/atomic v1.7.0 // indirect - golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c // indirect golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect golang.org/x/mod v0.5.0 // indirect golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 // indirect golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect + golang.org/x/sys v0.0.0-20211015200801-69063c4bb744 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index e0ece2fb..45fd5eea 100644 --- a/go.sum +++ b/go.sum @@ -559,8 +559,8 @@ golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c h1:9HhBz5L/UjnK9XLtiZhYAdue5BVKep3PMmS2LuPDt8k= -golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -722,11 +722,11 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw= -golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/sys v0.0.0-20211015200801-69063c4bb744 h1:KzbpndAYEM+4oHRp9JmB2ewj0NHHxO3Z0g7Gus2O1kk= +golang.org/x/sys v0.0.0-20211015200801-69063c4bb744/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/adapter/gql/generated.go b/internal/adapter/gql/generated.go index 3d368029..f8b180ff 100644 --- a/internal/adapter/gql/generated.go +++ b/internal/adapter/gql/generated.go @@ -6779,13 +6779,13 @@ type RemoveAssetPayload { assetId: ID! } -type SignupPayload { +type UpdateMePayload { user: User! - team: Team! } -type UpdateMePayload { +type SignupPayload { user: User! + team: Team! } type DeleteMePayload { diff --git a/internal/adapter/gql/resolver_mutation_user.go b/internal/adapter/gql/resolver_mutation_user.go index 7f1930a1..f276db5f 100644 --- a/internal/adapter/gql/resolver_mutation_user.go +++ b/internal/adapter/gql/resolver_mutation_user.go @@ -16,14 +16,14 @@ func (r *mutationResolver) Signup(ctx context.Context, input gqlmodel.SignupInpu if input.Secret != nil { secret = *input.Secret } - + sub := getSub(ctx) u, team, err := r.usecases.User.Signup(ctx, interfaces.SignupParam{ - Sub: getSub(ctx), + Sub: &sub, Lang: input.Lang, Theme: gqlmodel.ToTheme(input.Theme), UserID: id.UserIDFromRefID(input.UserID), TeamID: id.TeamIDFromRefID(input.TeamID), - Secret: secret, + Secret: &secret, }) if err != nil { return nil, err diff --git a/internal/adapter/http/user.go b/internal/adapter/http/user.go index 6b4624d6..827bef84 100644 --- a/internal/adapter/http/user.go +++ b/internal/adapter/http/user.go @@ -17,6 +17,22 @@ func NewUserController(usecase interfaces.User) *UserController { } } +type PasswordResetInput struct { + Email string `json:"email"` + Token string `json:"token"` + Password string `json:"password"` +} + +type SignupInput struct { + Sub *string `json:"sub"` + Secret *string `json:"secret"` + UserID *id.UserID `json:"userId"` + TeamID *id.TeamID `json:"teamId"` + Name *string `json:"username"` + Email *string `json:"email"` + Password *string `json:"password"` +} + type CreateVerificationInput struct { Email string `json:"email"` } @@ -33,24 +49,27 @@ type CreateUserInput struct { TeamID *id.TeamID `json:"teamId"` } -type CreateUserOutput struct { +type SignupOutput struct { ID string `json:"id"` Name string `json:"name"` Email string `json:"email"` } -func (c *UserController) CreateUser(ctx context.Context, input CreateUserInput) (interface{}, error) { +func (c *UserController) Signup(ctx context.Context, input SignupInput) (interface{}, error) { u, _, err := c.usecase.Signup(ctx, interfaces.SignupParam{ - Sub: input.Sub, - Secret: input.Secret, - UserID: input.UserID, - TeamID: input.TeamID, + Sub: input.Sub, + Secret: input.Secret, + UserID: input.UserID, + TeamID: input.TeamID, + Name: input.Name, + Email: input.Email, + Password: input.Password, }) if err != nil { return nil, err } - return CreateUserOutput{ + return SignupOutput{ ID: u.ID().String(), Name: u.Name(), Email: u.Email(), @@ -74,3 +93,18 @@ func (c *UserController) VerifyUser(ctx context.Context, code string) (interface Verified: u.Verification().IsVerified(), }, nil } + +func (c *UserController) StartPasswordReset(ctx context.Context, input PasswordResetInput) error { + err := c.usecase.StartPasswordReset(ctx, input.Email) + if err != nil { + return err + } + + // TODO: send password reset link via email + + return nil +} + +func (c *UserController) PasswordReset(ctx context.Context, input PasswordResetInput) error { + return c.usecase.PasswordReset(ctx, input.Password, input.Token) +} diff --git a/internal/app/public.go b/internal/app/public.go index 56f060aa..f122818f 100644 --- a/internal/app/public.go +++ b/internal/app/public.go @@ -26,12 +26,12 @@ func publicAPI( }) r.POST("/signup", func(c echo.Context) error { - var inp http1.CreateUserInput + var inp http1.SignupInput if err := c.Bind(&inp); err != nil { return &echo.HTTPError{Code: http.StatusBadRequest, Message: fmt.Errorf("failed to parse request body: %w", err)} } - output, err := controller.CreateUser(c.Request().Context(), inp) + output, err := controller.Signup(c.Request().Context(), inp) if err != nil { return err } @@ -39,6 +39,29 @@ func publicAPI( return c.JSON(http.StatusOK, output) }) + r.POST("/password-reset", func(c echo.Context) error { + var inp http1.PasswordResetInput + if err := c.Bind(&inp); err != nil { + return err + } + + if len(inp.Email) > 0 { + if err := controller.StartPasswordReset(c.Request().Context(), inp); err != nil { + return err + } + return c.JSON(http.StatusOK, true) + } + + if len(inp.Token) > 0 && len(inp.Password) > 0 { + if err := controller.PasswordReset(c.Request().Context(), inp); err != nil { + return err + } + return c.JSON(http.StatusOK, true) + } + + return &echo.HTTPError{Code: http.StatusBadRequest, Message: "Bad reset password request"} + }) + r.POST("/signup/verify", func(c echo.Context) error { var inp http1.CreateVerificationInput if err := c.Bind(&inp); err != nil { diff --git a/internal/infrastructure/memory/user.go b/internal/infrastructure/memory/user.go index 8444b498..4024b1ad 100644 --- a/internal/infrastructure/memory/user.go +++ b/internal/infrastructure/memory/user.go @@ -72,6 +72,24 @@ func (r *User) FindByAuth0Sub(ctx context.Context, auth0sub string) (*user.User, return nil, rerror.ErrNotFound } +func (r *User) FindByPasswordResetRequest(ctx context.Context, token string) (*user.User, error) { + r.lock.Lock() + defer r.lock.Unlock() + + if token == "" { + return nil, rerror.ErrInvalidParams + } + + for _, u := range r.data { + pwdReq := u.PasswordReset() + if pwdReq != nil && pwdReq.Token == token { + return &u, nil + } + } + + return nil, rerror.ErrNotFound +} + func (r *User) FindByEmail(ctx context.Context, email string) (*user.User, error) { r.lock.Lock() defer r.lock.Unlock() diff --git a/internal/infrastructure/mongo/mongodoc/user.go b/internal/infrastructure/mongo/mongodoc/user.go index f44ddfe0..2f52da3e 100644 --- a/internal/infrastructure/mongo/mongodoc/user.go +++ b/internal/infrastructure/mongo/mongodoc/user.go @@ -10,16 +10,23 @@ import ( user1 "github.com/reearth/reearth-backend/pkg/user" ) +type PasswordResetDocument struct { + Token string + CreatedAt time.Time +} + type UserDocument struct { - ID string - Name string - Email string - Auth0Sub string - Auth0SubList []string - Team string - Lang string - Theme string - Verification *UserVerificationDoc + ID string + Name string + Email string + Auth0Sub string + Auth0SubList []string + Team string + Lang string + Theme string + Password []byte + PasswordReset *PasswordResetDocument + Verification *UserVerificationDoc } type UserVerificationDoc struct { @@ -64,16 +71,27 @@ func NewUser(user *user1.User) (*UserDocument, string) { Verified: user.Verification().IsVerified(), } } + pwdReset := user.PasswordReset() + + var pwdResetDoc *PasswordResetDocument + if pwdReset != nil { + pwdResetDoc = &PasswordResetDocument{ + Token: pwdReset.Token, + CreatedAt: pwdReset.CreatedAt, + } + } return &UserDocument{ - ID: id, - Name: user.Name(), - Email: user.Email(), - Auth0SubList: authsdoc, - Team: user.Team().String(), - Lang: user.Lang().String(), - Theme: string(user.Theme()), - Verification: v, + ID: id, + Name: user.Name(), + Email: user.Email(), + Auth0SubList: authsdoc, + Team: user.Team().String(), + Lang: user.Lang().String(), + Theme: string(user.Theme()), + Verification: v, + Password: user.Password(), + PasswordReset: pwdResetDoc, }, id } @@ -98,7 +116,7 @@ func (d *UserDocument) Model() (*user1.User, error) { v = user.VerificationFrom(d.Verification.Code, d.Verification.Expiration, d.Verification.Verified) } - user, err := user1.New(). + u, err := user1.New(). ID(uid). Name(d.Name). Email(d.Email). @@ -106,10 +124,23 @@ func (d *UserDocument) Model() (*user1.User, error) { Team(tid). LangFrom(d.Lang). Verification(v). + Password(d.Password). + PasswordReset(d.PasswordReset.Model()). Theme(user.Theme(d.Theme)). Build() + if err != nil { return nil, err } - return user, nil + return u, nil +} + +func (d *PasswordResetDocument) Model() *user1.PasswordReset { + if d == nil { + return nil + } + return &user1.PasswordReset{ + Token: d.Token, + CreatedAt: d.CreatedAt, + } } diff --git a/internal/infrastructure/mongo/user.go b/internal/infrastructure/mongo/user.go index 945685f4..60c73d9c 100644 --- a/internal/infrastructure/mongo/user.go +++ b/internal/infrastructure/mongo/user.go @@ -78,6 +78,13 @@ func (r *userRepo) FindByVerification(ctx context.Context, code string) (*user.U return r.findOne(ctx, filter) } +func (r *userRepo) FindByPasswordResetRequest(ctx context.Context, pwdResetToken string) (*user.User, error) { + filter := bson.D{ + {Key: "passwordreset.token", Value: pwdResetToken}, + } + return r.findOne(ctx, filter) +} + func (r *userRepo) Save(ctx context.Context, user *user.User) error { doc, id := mongodoc.NewUser(user) return r.client.SaveOne(ctx, id, doc) diff --git a/internal/usecase/interactor/emails/password_reset_html.tmpl b/internal/usecase/interactor/emails/password_reset_html.tmpl new file mode 100644 index 00000000..70d038e6 --- /dev/null +++ b/internal/usecase/interactor/emails/password_reset_html.tmpl @@ -0,0 +1,436 @@ + + + + + + + Re:Earth reset password + + + + + + + + + + + + + \ No newline at end of file diff --git a/internal/usecase/interactor/emails/password_reset_text.tmpl b/internal/usecase/interactor/emails/password_reset_text.tmpl new file mode 100644 index 00000000..03d4fd2b --- /dev/null +++ b/internal/usecase/interactor/emails/password_reset_text.tmpl @@ -0,0 +1,7 @@ +ReEarth +You have submitted a password change request! If it was you, please confirm the password change. +To reset your password click on the following link: + +{{ . }} + +If you are having any issues with your account, please don't hesitate to contact us by replying to this mail. Thank you! \ No newline at end of file diff --git a/internal/usecase/interactor/user.go b/internal/usecase/interactor/user.go index 73fb2a56..d6f235c1 100644 --- a/internal/usecase/interactor/user.go +++ b/internal/usecase/interactor/user.go @@ -1,14 +1,20 @@ package interactor import ( + "bytes" "context" + _ "embed" "errors" + htmlTmpl "html/template" + "net/mail" + textTmpl "text/template" "github.com/reearth/reearth-backend/internal/usecase" "github.com/reearth/reearth-backend/internal/usecase/gateway" "github.com/reearth/reearth-backend/internal/usecase/interfaces" "github.com/reearth/reearth-backend/internal/usecase/repo" "github.com/reearth/reearth-backend/pkg/id" + "github.com/reearth/reearth-backend/pkg/log" "github.com/reearth/reearth-backend/pkg/project" "github.com/reearth/reearth-backend/pkg/rerror" "github.com/reearth/reearth-backend/pkg/user" @@ -32,6 +38,28 @@ type User struct { signupSecret string } +var ( + //go:embed emails/password_reset_html.tmpl + passwordResetHTMLTMPLStr string + //go:embed emails/password_reset_text.tmpl + passwordResetTextTMPLStr string + + passwordResetTextTMPL *textTmpl.Template + passwordResetHTMLTMPL *htmlTmpl.Template +) + +func init() { + var err error + passwordResetTextTMPL, err = textTmpl.New("passwordReset").Parse(passwordResetTextTMPLStr) + if err != nil { + log.Panicf("password reset email template parse error: %s\n", err) + } + passwordResetHTMLTMPL, err = htmlTmpl.New("passwordReset").Parse(passwordResetHTMLTMPLStr) + if err != nil { + log.Panicf("password reset email template parse error: %s\n", err) + } +} + func NewUser(r *repo.Container, g *gateway.Container, signupSecret string) interfaces.User { return &User{ userRepo: r.User, @@ -79,43 +107,103 @@ func (i *User) Fetch(ctx context.Context, ids []id.UserID, operator *usecase.Ope } func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (u *user.User, _ *user.Team, err error) { - if i.signupSecret != "" && inp.Secret != i.signupSecret { - return nil, nil, interfaces.ErrSignupInvalidSecret - } + var team *user.Team + var tx repo.Tx + var email, name string + var auth *user.Auth + + if inp.Secret != nil && inp.Sub != nil { + // Auth0 + if i.signupSecret != "" && *inp.Secret != i.signupSecret { + return nil, nil, interfaces.ErrSignupInvalidSecret + } - if len(inp.Sub) == 0 { - return nil, nil, errors.New("sub is required") - } + if len(*inp.Sub) == 0 { + return nil, nil, errors.New("sub is required") + } - tx, err := i.transaction.Begin() - if err != nil { - return - } - defer func() { - if err2 := tx.End(ctx); err == nil && err2 != nil { - err = err2 + tx, err = i.transaction.Begin() + if err != nil { + return } - }() + defer func() { + if err2 := tx.End(ctx); err == nil && err2 != nil { + err = err2 + } + }() - // Check if user and team already exists - existed, err := i.userRepo.FindByAuth0Sub(ctx, inp.Sub) - if err != nil && !errors.Is(err, rerror.ErrNotFound) { - return nil, nil, err - } - if existed != nil { - return nil, nil, errors.New("existed user") - } + // Check if user already exists + existed, err := i.userRepo.FindByAuth0Sub(ctx, *inp.Sub) + if err != nil && !errors.Is(err, rerror.ErrNotFound) { + return nil, nil, err + } + if existed != nil { + return nil, nil, errors.New("existed user") + } - if inp.UserID != nil { - existed, err := i.userRepo.FindByID(ctx, *inp.UserID) + if inp.UserID != nil { + existed, err := i.userRepo.FindByID(ctx, *inp.UserID) + if err != nil && !errors.Is(err, rerror.ErrNotFound) { + return nil, nil, err + } + if existed != nil { + return nil, nil, errors.New("existed user") + } + } + + // Fetch user info + ui, err := i.authenticator.FetchUser(*inp.Sub) + if err != nil { + return nil, nil, err + } + + // Check if user and team already exists + existed, err = i.userRepo.FindByEmail(ctx, ui.Email) if err != nil && !errors.Is(err, rerror.ErrNotFound) { return nil, nil, err } if existed != nil { return nil, nil, errors.New("existed user") } + name = ui.Name + email = ui.Email + auth = user.AuthFromAuth0Sub(*inp.Sub).Ref() + + } else if inp.Name != nil && inp.Email != nil && inp.Password != nil { + if *inp.Name == "" { + return nil, nil, interfaces.ErrSignupInvalidName + } + if _, err := mail.ParseAddress(*inp.Email); err != nil { + return nil, nil, interfaces.ErrInvalidUserEmail + } + if *inp.Password == "" { + return nil, nil, interfaces.ErrSignupInvalidPassword + } + + tx, err = i.transaction.Begin() + if err != nil { + return + } + defer func() { + if err2 := tx.End(ctx); err == nil && err2 != nil { + err = err2 + } + }() + + // Check if user email already exists + existed, err := i.userRepo.FindByEmail(ctx, *inp.Email) + if err != nil && !errors.Is(err, rerror.ErrNotFound) { + return nil, nil, err + } + if existed != nil { + return nil, nil, errors.New("existed user email") + } + + name = *inp.Name + email = *inp.Email } + // Check if team already exists if inp.TeamID != nil { existed, err := i.teamRepo.FindByID(ctx, *inp.TeamID) if err != nil && !errors.Is(err, rerror.ErrNotFound) { @@ -126,27 +214,12 @@ func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (u *user. } } - // Fetch user info - ui, err := i.authenticator.FetchUser(inp.Sub) - if err != nil { - return nil, nil, err - } - - // Check if user and team already exists - var team *user.Team - existed, err = i.userRepo.FindByEmail(ctx, ui.Email) - if err != nil && !errors.Is(err, rerror.ErrNotFound) { - return nil, nil, err - } - if existed != nil { - return nil, nil, errors.New("existed user") - } - // Initialize user and team u, team, err = user.Init(user.InitParams{ - Email: ui.Email, - Name: ui.Name, - Auth0Sub: inp.Sub, + Email: name, + Name: email, + Sub: auth, + Password: *inp.Password, Lang: inp.Lang, Theme: inp.Theme, UserID: inp.UserID, @@ -161,8 +234,10 @@ func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (u *user. if err := i.teamRepo.Save(ctx, team); err != nil { return nil, nil, err } + if tx != nil { + tx.Commit() + } - tx.Commit() return u, team, nil } @@ -171,11 +246,14 @@ func (i *User) GetUserByCredentials(ctx context.Context, inp interfaces.GetUserB if err != nil && !errors.Is(rerror.ErrNotFound, err) { return nil, err } else if u == nil { - return nil, interfaces.ErrInvalidUserCredentials + return nil, interfaces.ErrInvalidUserEmail } - // TODO: Check user password - if inp.Password != "123123123" { - return nil, interfaces.ErrInvalidUserCredentials + matched, err := u.MatchPassword(inp.Password) + if err != nil { + return nil, err + } + if !matched { + return nil, interfaces.ErrSignupInvalidPassword } return u, nil } @@ -188,6 +266,91 @@ func (i *User) GetUserBySubject(ctx context.Context, sub string) (u *user.User, return u, nil } +func (i *User) StartPasswordReset(ctx context.Context, email string) error { + tx, err := i.transaction.Begin() + if err != nil { + return err + } + defer func() { + if err2 := tx.End(ctx); err == nil && err2 != nil { + err = err2 + } + }() + + u, err := i.userRepo.FindByEmail(ctx, email) + if err != nil { + return err + } + + pr := user.NewPasswordReset() + u.SetPasswordReset(pr) + + if err := i.userRepo.Save(ctx, u); err != nil { + return err + } + + var TextOut, HTMLOut bytes.Buffer + link := "localhost:3000/?pwd-reset-token=" + pr.Token + err = passwordResetTextTMPL.Execute(&TextOut, link) + if err != nil { + return err + } + err = passwordResetHTMLTMPL.Execute(&HTMLOut, link) + if err != nil { + return err + } + + err = i.mailer.SendMail([]gateway.Contact{ + { + Email: u.Email(), + Name: u.Name(), + }, + }, "Password reset", TextOut.String(), HTMLOut.String()) + if err != nil { + return err + } + + tx.Commit() + return nil +} + +func (i *User) PasswordReset(ctx context.Context, password, token string) error { + tx, err := i.transaction.Begin() + if err != nil { + return err + } + defer func() { + if err2 := tx.End(ctx); err == nil && err2 != nil { + err = err2 + } + }() + + u, err := i.userRepo.FindByPasswordResetRequest(ctx, token) + if err != nil { + return err + } + + passwordReset := u.PasswordReset() + ok := passwordReset.Validate(token) + + if !ok { + return interfaces.ErrUserInvalidPasswordReset + } + + u.SetPasswordReset(nil) + + if err := u.SetPassword(password); err != nil { + return err + } + + if err := i.userRepo.Save(ctx, u); err != nil { + return err + } + + tx.Commit() + return nil +} + func (i *User) UpdateMe(ctx context.Context, p interfaces.UpdateMeParam, operator *usecase.Operator) (u *user.User, err error) { if err := i.OnlyOperator(operator); err != nil { return nil, err diff --git a/internal/usecase/interfaces/user.go b/internal/usecase/interfaces/user.go index f476a79a..acca0bb5 100644 --- a/internal/usecase/interfaces/user.go +++ b/internal/usecase/interfaces/user.go @@ -13,18 +13,24 @@ import ( var ( ErrUserInvalidPasswordConfirmation = errors.New("invalid password confirmation") + ErrUserInvalidPasswordReset = errors.New("invalid password reset request") ErrUserInvalidLang = errors.New("invalid lang") ErrSignupInvalidSecret = errors.New("invalid secret") - ErrInvalidUserCredentials = errors.New("invalid credentials") + ErrSignupInvalidName = errors.New("invalid name") + ErrInvalidUserEmail = errors.New("invalid email") + ErrSignupInvalidPassword = errors.New("invalid password") ) type SignupParam struct { - Sub string - Lang *language.Tag - Theme *user.Theme - UserID *id.UserID - TeamID *id.TeamID - Secret string + Sub *string + UserID *id.UserID + Secret *string + Name *string + Email *string + Password *string + Lang *language.Tag + Theme *user.Theme + TeamID *id.TeamID } type GetUserByCredentials struct { @@ -48,6 +54,8 @@ type User interface { VerifyUser(context.Context, string) (*user.User, error) GetUserByCredentials(context.Context, GetUserByCredentials) (*user.User, error) GetUserBySubject(context.Context, string) (*user.User, error) + StartPasswordReset(context.Context, string) error + PasswordReset(context.Context, string, string) error UpdateMe(context.Context, UpdateMeParam, *usecase.Operator) (*user.User, error) RemoveMyAuth(context.Context, string, *usecase.Operator) (*user.User, error) SearchUser(context.Context, string, *usecase.Operator) (*user.User, error) diff --git a/internal/usecase/repo/user.go b/internal/usecase/repo/user.go index 472aef9d..fdc3f276 100644 --- a/internal/usecase/repo/user.go +++ b/internal/usecase/repo/user.go @@ -14,6 +14,7 @@ type User interface { FindByEmail(context.Context, string) (*user.User, error) FindByNameOrEmail(context.Context, string) (*user.User, error) FindByVerification(context.Context, string) (*user.User, error) + FindByPasswordResetRequest(context.Context, string) (*user.User, error) Save(context.Context, *user.User) error Remove(context.Context, id.UserID) error } diff --git a/pkg/log/log.go b/pkg/log/log.go index 23ed1c84..a42a6f00 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -102,3 +102,7 @@ func Errorln(args ...interface{}) { func Fatalln(args ...interface{}) { logrus.Fatalln(args...) } + +func Panicf(format string, args ...interface{}) { + logrus.Panicf(format, args...) +} diff --git a/pkg/user/auth.go b/pkg/user/auth.go index ed149859..08c3dd50 100644 --- a/pkg/user/auth.go +++ b/pkg/user/auth.go @@ -1,6 +1,8 @@ package user -import "strings" +import ( + "strings" +) type Auth struct { Provider string @@ -18,3 +20,15 @@ func AuthFromAuth0Sub(sub string) Auth { func (a Auth) IsAuth0() bool { return a.Provider == "auth0" } + +func (a Auth) Ref() *Auth { + a2 := a + return &a2 +} + +func GenReearthSub(userID string) *Auth { + return &Auth{ + Provider: "reearth", + Sub: userID, + } +} diff --git a/pkg/user/builder.go b/pkg/user/builder.go index d9a3f31e..c6e265bf 100644 --- a/pkg/user/builder.go +++ b/pkg/user/builder.go @@ -6,7 +6,8 @@ import ( ) type Builder struct { - u *User + u *User + passwordText string } func New() *Builder { @@ -17,6 +18,11 @@ func (b *Builder) Build() (*User, error) { if id.ID(b.u.id).IsNil() { return nil, id.ErrInvalidID } + if b.passwordText != "" { + if err := b.u.SetPassword(b.passwordText); err != nil { + return nil, ErrEncodingPassword + } + } return b.u, nil } @@ -48,6 +54,20 @@ func (b *Builder) Email(email string) *Builder { return b } +func (b *Builder) Password(p []byte) *Builder { + if p == nil { + b.u.password = nil + } else { + b.u.password = append(p[:0:0], p...) + } + return b +} + +func (b *Builder) PasswordPlainText(p string) *Builder { + b.passwordText = p + return b +} + func (b *Builder) Team(team id.TeamID) *Builder { b.u.team = team return b @@ -77,6 +97,11 @@ func (b *Builder) Auths(auths []Auth) *Builder { return b } +func (b *Builder) PasswordReset(pr *PasswordReset) *Builder { + b.u.passwordReset = pr + return b +} + func (b *Builder) Verification(v *Verification) *Builder { b.u.verification = v return b diff --git a/pkg/user/builder_test.go b/pkg/user/builder_test.go index d87ef112..76b34085 100644 --- a/pkg/user/builder_test.go +++ b/pkg/user/builder_test.go @@ -14,6 +14,7 @@ func TestBuilder_ID(t *testing.T) { uid := id.NewUserID() b := New().ID(uid).MustBuild() assert.Equal(t, uid, b.ID()) + assert.Nil(t, b.passwordReset) } func TestBuilder_Name(t *testing.T) { @@ -88,6 +89,31 @@ func TestBuilder_LangFrom(t *testing.T) { } } +func TestBuilder_PasswordReset(t *testing.T) { + testCases := []struct { + Name, Token string + CreatedAt time.Time + Expected PasswordReset + }{ + { + Name: "Test1", + Token: "xyz", + CreatedAt: time.Unix(0, 0), + Expected: PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(0, 0), + }, + }, + } + for _, tc := range testCases { + t.Run(tc.Name, func(tt *testing.T) { + tt.Parallel() + // u := New().NewID().PasswordReset(tc.Token, tc.CreatedAt).MustBuild() + // assert.Equal(t, tc.Expected, *u.passwordReset) + }) + } +} + func TestNew(t *testing.T) { b := New() assert.NotNil(t, b) @@ -97,37 +123,48 @@ func TestNew(t *testing.T) { func TestBuilder_Build(t *testing.T) { uid := id.NewUserID() tid := id.NewTeamID() + pass, _ := encodePassword("pass") + testCases := []struct { Name, UserName, Lang, Email string UID id.UserID TID id.TeamID Auths []Auth + PasswordBin []byte Expected *User err error }{ { - Name: "Success build user", - UserName: "xxx", - Email: "xx@yy.zz", - Lang: "en", - UID: uid, - TID: tid, + Name: "Success build user", + UserName: "xxx", + Email: "xx@yy.zz", + Lang: "en", + UID: uid, + PasswordBin: pass, + TID: tid, Auths: []Auth{ { Provider: "ppp", Sub: "sss", }, }, - Expected: New(). - ID(uid). - Team(tid). - Email("xx@yy.zz"). - Name("xxx"). - Auths([]Auth{{Provider: "ppp", Sub: "sss"}}). - LangFrom("en"). - MustBuild(), + Expected: &User{ + id: uid, + name: "xxx", + email: "xx@yy.zz", + password: pass, + team: tid, + auths: []Auth{ + { + Provider: "ppp", + Sub: "sss", + }, + }, + lang: language.MustParse("en"), + }, err: nil, - }, { + }, + { Name: "failed invalid id", Expected: nil, err: id.ErrInvalidID, @@ -136,7 +173,7 @@ func TestBuilder_Build(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(tt *testing.T) { tt.Parallel() - res, err := New().ID(tc.UID).Name(tc.UserName).Auths(tc.Auths).LangFrom(tc.Lang).Email(tc.Email).Team(tc.TID).Build() + res, err := New().ID(tc.UID).Name(tc.UserName).Auths(tc.Auths).LangFrom(tc.Lang).Email(tc.Email).Password(tc.PasswordBin).Team(tc.TID).Build() if err == nil { assert.Equal(tt, tc.Expected, res) } else { @@ -149,37 +186,48 @@ func TestBuilder_Build(t *testing.T) { func TestBuilder_MustBuild(t *testing.T) { uid := id.NewUserID() tid := id.NewTeamID() + pass, _ := encodePassword("pass") testCases := []struct { Name, UserName, Lang, Email string UID id.UserID TID id.TeamID + PasswordBin []byte Auths []Auth Expected *User err error }{ + { - Name: "Success build user", - UserName: "xxx", - Email: "xx@yy.zz", - Lang: "en", - UID: uid, - TID: tid, + Name: "Success build user", + UserName: "xxx", + Email: "xx@yy.zz", + Lang: "en", + UID: uid, + PasswordBin: pass, + TID: tid, Auths: []Auth{ { Provider: "ppp", Sub: "sss", }, }, - Expected: New(). - ID(uid). - Team(tid). - Email("xx@yy.zz"). - Name("xxx"). - Auths([]Auth{{Provider: "ppp", Sub: "sss"}}). - LangFrom("en"). - MustBuild(), + Expected: &User{ + id: uid, + name: "xxx", + email: "xx@yy.zz", + password: pass, + team: tid, + auths: []Auth{ + { + Provider: "ppp", + Sub: "sss", + }, + }, + lang: language.MustParse("en"), + }, err: nil, - }, { + }, + { Name: "failed invalid id", Expected: nil, err: id.ErrInvalidID, @@ -195,7 +243,15 @@ func TestBuilder_MustBuild(t *testing.T) { } }() - res = New().ID(tc.UID).Name(tc.UserName).Auths(tc.Auths).LangFrom(tc.Lang).Email(tc.Email).Team(tc.TID).MustBuild() + res = New(). + ID(tc.UID). + Name(tc.UserName). + Auths(tc.Auths). + Password(tc.PasswordBin). + LangFrom(tc.Lang). + Email(tc.Email). + Team(tc.TID). + MustBuild() }) } } diff --git a/pkg/user/initializer.go b/pkg/user/initializer.go index f28c19a1..0ed23272 100644 --- a/pkg/user/initializer.go +++ b/pkg/user/initializer.go @@ -8,7 +8,8 @@ import ( type InitParams struct { Email string Name string - Auth0Sub string + Sub *Auth + Password string Lang *language.Tag Theme *Theme UserID *id.UserID @@ -29,13 +30,17 @@ func Init(p InitParams) (*User, *Team, error) { t := ThemeDefault p.Theme = &t } + if p.Sub == nil { + p.Sub = GenReearthSub(p.UserID.String()) + } u, err := New(). ID(*p.UserID). Name(p.Name). Email(p.Email). - Auths([]Auth{AuthFromAuth0Sub(p.Auth0Sub)}). + Auths([]Auth{*p.Sub}). Lang(*p.Lang). + PasswordPlainText(p.Password). Theme(*p.Theme). Build() if err != nil { diff --git a/pkg/user/initializer_test.go b/pkg/user/initializer_test.go index 769b87ab..2749b4c6 100644 --- a/pkg/user/initializer_test.go +++ b/pkg/user/initializer_test.go @@ -11,27 +11,35 @@ import ( func TestInit(t *testing.T) { uid := id.NewUserID() tid := id.NewTeamID() + expectedSub := Auth{ + Provider: "###", + Sub: "###", + } testCases := []struct { - Name, Email, Username, Sub string - UID *id.UserID - TID *id.TeamID - ExpectedUser *User - ExpectedTeam *Team - Err error + Name, Email, Username string + Sub Auth + UID *id.UserID + TID *id.TeamID + ExpectedUser *User + ExpectedTeam *Team + Err error }{ { Name: "Success create user", Email: "xx@yy.zz", Username: "nnn", - Sub: "###", - UID: &uid, - TID: &tid, + Sub: Auth{ + Provider: "###", + Sub: "###", + }, + UID: &uid, + TID: &tid, ExpectedUser: New(). ID(uid). Email("xx@yy.zz"). Name("nnn"). Team(tid). - Auths([]Auth{AuthFromAuth0Sub("###")}). + Auths([]Auth{expectedSub}). MustBuild(), ExpectedTeam: NewTeam(). ID(tid). @@ -45,15 +53,18 @@ func TestInit(t *testing.T) { Name: "Success nil team id", Email: "xx@yy.zz", Username: "nnn", - Sub: "###", - UID: &uid, - TID: nil, + Sub: Auth{ + Provider: "###", + Sub: "###", + }, + UID: &uid, + TID: nil, ExpectedUser: New(). ID(uid). Email("xx@yy.zz"). Name("nnn"). Team(tid). - Auths([]Auth{AuthFromAuth0Sub("###")}). + Auths([]Auth{expectedSub}). MustBuild(), ExpectedTeam: NewTeam(). NewID(). @@ -67,15 +78,18 @@ func TestInit(t *testing.T) { Name: "Success nil id", Email: "xx@yy.zz", Username: "nnn", - Sub: "###", - UID: nil, - TID: &tid, + Sub: Auth{ + Provider: "###", + Sub: "###", + }, + UID: nil, + TID: &tid, ExpectedUser: New(). NewID(). Email("xx@yy.zz"). Name("nnn"). Team(tid). - Auths([]Auth{AuthFromAuth0Sub("###")}). + Auths([]Auth{expectedSub}). MustBuild(), ExpectedTeam: NewTeam(). ID(tid). @@ -91,11 +105,11 @@ func TestInit(t *testing.T) { t.Run(tc.Name, func(tt *testing.T) { tt.Parallel() u, t, err := Init(InitParams{ - Email: tc.Email, - Name: tc.Username, - Auth0Sub: tc.Sub, - UserID: tc.UID, - TeamID: tc.TID, + Email: tc.Email, + Name: tc.Username, + Sub: &tc.Sub, + UserID: tc.UID, + TeamID: tc.TID, }) if err == nil { assert.Equal(tt, tc.ExpectedUser.Email(), u.Email()) diff --git a/pkg/user/password_reset.go b/pkg/user/password_reset.go new file mode 100644 index 00000000..6ec20872 --- /dev/null +++ b/pkg/user/password_reset.go @@ -0,0 +1,44 @@ +package user + +import ( + "time" + + "github.com/google/uuid" +) + +var timeNow = time.Now + +type PasswordReset struct { + Token string + CreatedAt time.Time +} + +func NewPasswordReset() *PasswordReset { + return &PasswordReset{ + Token: generateToken(), + CreatedAt: timeNow(), + } +} + +func PasswordResetFrom(token string, createdAt time.Time) *PasswordReset { + return &PasswordReset{ + Token: token, + CreatedAt: createdAt, + } +} + +func generateToken() string { + return uuid.New().String() +} + +func (pr *PasswordReset) Validate(token string) bool { + return pr != nil && pr.Token == token && pr.CreatedAt.Add(24*time.Hour).After(time.Now()) +} + +func (pr *PasswordReset) Clone() *PasswordReset { + if pr == nil { + return nil + } + pr2 := PasswordResetFrom(pr.Token, pr.CreatedAt) + return pr2 +} diff --git a/pkg/user/password_reset_test.go b/pkg/user/password_reset_test.go new file mode 100644 index 00000000..253a7b92 --- /dev/null +++ b/pkg/user/password_reset_test.go @@ -0,0 +1,103 @@ +package user + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewPasswordReset(t *testing.T) { + mockTime := time.Now() + timeNow = func() time.Time { + return mockTime + } + pr := NewPasswordReset() + assert.NotNil(t, pr) + assert.NotEmpty(t, pr.Token) + assert.Equal(t, mockTime, pr.CreatedAt) +} + +func TestPasswordReset_Validate(t *testing.T) { + tests := []struct { + name string + pr *PasswordReset + token string + want bool + }{ + { + name: "valid", + pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Now(), + }, + token: "xyz", + want: true, + }, + { + name: "wrong token", + pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Now(), + }, + token: "xxx", + want: false, + }, + { + name: "old request", + pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Now().Add(-24 * time.Hour), + }, + token: "xyz", + want: false, + }, + { + name: "nil request", + pr: nil, + token: "xyz", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.pr.Validate(tt.token)) + }) + } +} + +func Test_generateToken(t *testing.T) { + t1 := generateToken() + t2 := generateToken() + + assert.NotNil(t, t1) + assert.NotNil(t, t2) + assert.NotEmpty(t, t1) + assert.NotEmpty(t, t2) + assert.NotEqual(t, t1, t2) + +} + +func TestPasswordResetFrom(t *testing.T) { + tests := []struct { + name string + token string + createdAt time.Time + want *PasswordReset + }{ + { + name: "prFrom", + token: "xyz", + createdAt: time.Unix(1, 1), + want: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, PasswordResetFrom(tt.token, tt.createdAt)) + }) + } +} diff --git a/pkg/user/user.go b/pkg/user/user.go index f5946b02..dcead298 100644 --- a/pkg/user/user.go +++ b/pkg/user/user.go @@ -1,19 +1,30 @@ package user import ( + "errors" + + "golang.org/x/crypto/bcrypt" + "github.com/reearth/reearth-backend/pkg/id" "golang.org/x/text/language" ) +var ( + ErrEncodingPassword = errors.New("error encoding password") + ErrInvalidPassword = errors.New("error invalid password") +) + type User struct { - id id.UserID - name string - email string - team id.TeamID - auths []Auth - lang language.Tag - theme Theme - verification *Verification + id id.UserID + name string + email string + password []byte + team id.TeamID + auths []Auth + lang language.Tag + theme Theme + verification *Verification + passwordReset *PasswordReset } func (u *User) ID() id.UserID { @@ -40,6 +51,10 @@ func (u *User) Theme() Theme { return u.theme } +func (u *User) Password() []byte { + return u.password +} + func (u *User) UpdateName(name string) { u.name = name } @@ -136,6 +151,46 @@ func (u *User) ClearAuths() { u.auths = []Auth{} } +func (u *User) SetPassword(pass string) error { + p, err := encodePassword(pass) + if err != nil { + return err + } + u.password = p + return nil +} + +func (u *User) MatchPassword(pass string) (bool, error) { + if u == nil || len(u.password) == 0 { + return false, nil + } + return verifyPassword(pass, u.password) +} + +func encodePassword(pass string) ([]byte, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(pass), 14) + return bytes, err +} + +func verifyPassword(toVerify string, encoded []byte) (bool, error) { + err := bcrypt.CompareHashAndPassword(encoded, []byte(toVerify)) + if err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return false, nil + } + return false, err + } + return true, nil +} + +func (u *User) PasswordReset() *PasswordReset { + return u.passwordReset +} + +func (u *User) SetPasswordReset(pr *PasswordReset) { + u.passwordReset = pr.Clone() +} + func (u *User) SetVerification(v *Verification) { u.verification = v } diff --git a/pkg/user/user_test.go b/pkg/user/user_test.go index fcebb5fa..a8c9338c 100644 --- a/pkg/user/user_test.go +++ b/pkg/user/user_test.go @@ -317,6 +317,180 @@ func TestUser_GetAuthByProvider(t *testing.T) { } } +func TestUser_MatchPassword(t *testing.T) { + encodedPass, _ := encodePassword("test") + type args struct { + pass string + } + tests := []struct { + name string + password []byte + args args + want bool + wantErr bool + }{ + { + name: "passwords should match", + password: encodedPass, + args: args{ + pass: "test", + }, + want: true, + wantErr: false, + }, + { + name: "passwords shouldn't match", + password: encodedPass, + args: args{ + pass: "xxx", + }, + want: false, + wantErr: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(tt *testing.T) { + u := &User{ + password: tc.password, + } + got, err := u.MatchPassword(tc.args.pass) + assert.Equal(tt, tc.want, got) + if tc.wantErr { + assert.Error(tt, err) + } else { + assert.NoError(tt, err) + } + }) + } +} + +func TestUser_SetPassword(t *testing.T) { + type args struct { + pass string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "should set the password", + args: args{ + pass: "test", + }, + want: "test", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(tt *testing.T) { + u := &User{} + _ = u.SetPassword(tc.args.pass) + got, err := verifyPassword(tc.want, u.password) + assert.NoError(tt, err) + assert.True(tt, got) + }) + } +} + +func TestUser_PasswordReset(t *testing.T) { + testCases := []struct { + Name string + User *User + Expected *PasswordReset + }{ + { + Name: "not password request", + User: New().NewID().MustBuild(), + Expected: nil, + }, + { + Name: "create new password request over existing one", + User: New().NewID().PasswordReset(&PasswordReset{"xzy", time.Unix(0, 0)}).MustBuild(), + Expected: &PasswordReset{ + Token: "xzy", + CreatedAt: time.Unix(0, 0), + }, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(tt *testing.T) { + tt.Parallel() + assert.Equal(tt, tc.Expected, tc.User.PasswordReset()) + }) + } +} + +func TestUser_SetPasswordReset(t *testing.T) { + tests := []struct { + Name string + User *User + Pr *PasswordReset + Expected *PasswordReset + }{ + { + Name: "nil", + User: New().NewID().MustBuild(), + Pr: nil, + Expected: nil, + }, + { + Name: "nil", + User: New().NewID().MustBuild(), + Pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + Expected: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + }, + { + Name: "create new password request", + User: New().NewID().MustBuild(), + Pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + Expected: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + }, + { + Name: "create new password request over existing one", + User: New().NewID().PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), + Pr: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + Expected: &PasswordReset{ + Token: "xyz", + CreatedAt: time.Unix(1, 1), + }, + }, + { + Name: "remove none existing password request", + User: New().NewID().MustBuild(), + Pr: nil, + Expected: nil, + }, + { + Name: "remove existing password request", + User: New().NewID().PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), + Pr: nil, + Expected: nil, + }, + } + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + tt.User.SetPasswordReset(tt.Pr) + assert.Equal(t, tt.Expected, tt.User.PasswordReset()) + }) + } +} + func TestUser_SetVerification(t *testing.T) { input := &User{} v := &Verification{ diff --git a/schema.graphql b/schema.graphql index 31ca85d2..e8d84385 100644 --- a/schema.graphql +++ b/schema.graphql @@ -1140,13 +1140,13 @@ type RemoveAssetPayload { assetId: ID! } -type SignupPayload { +type UpdateMePayload { user: User! - team: Team! } -type UpdateMePayload { +type SignupPayload { user: User! + team: Team! } type DeleteMePayload {