diff --git a/handlers/auth_test.go b/handlers/auth_test.go index c5963ef..de91763 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -13,7 +13,6 @@ import ( api "github.com/alexferl/golib/http/api/server" "github.com/labstack/echo/v4" "github.com/spf13/viper" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -82,7 +81,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Login_200() { if viper.GetBool(config.CSRFEnabled) { expected = 3 } - if assert.Equal(s.T(), expected, len(resp.Result().Cookies())) { + if s.Assert().Equal(expected, len(resp.Result().Cookies())) { cookies := 0 for _, c := range resp.Result().Cookies() { if c.Name == viper.GetString(config.JWTAccessTokenCookieName) { @@ -95,14 +94,14 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Login_200() { cookies++ } } - assert.Equal(s.T(), expected, cookies) + s.Assert().Equal(expected, cookies) } - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.NotEqual(s.T(), "", result.AccessToken) - assert.NotEqual(s.T(), "", result.ExpiresIn) - assert.NotEqual(s.T(), "", result.RefreshToken) - assert.NotEqual(s.T(), "", result.TokenType) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().NotEqual("", result.AccessToken) + s.Assert().NotEqual("", result.ExpiresIn) + s.Assert().NotEqual("", result.RefreshToken) + s.Assert().NotEqual("", result.TokenType) }) } } @@ -140,8 +139,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Login_401() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Contains(s.T(), "invalid email or password", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Contains("invalid email or password", result.Message) }) } } @@ -155,7 +154,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Login_400() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusBadRequest, resp.Code) + s.Assert().Equal(http.StatusBadRequest, resp.Code) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_204_Cookie() { @@ -177,7 +176,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_204_Cookie() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusNoContent, resp.Code) + s.Assert().Equal(http.StatusNoContent, resp.Code) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Cookie_Invalid() { @@ -194,8 +193,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Cookie_Invalid() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), jwtMw.ErrTokenInvalid, result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal(jwtMw.ErrTokenInvalid, result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Cookie_Mismatch() { @@ -217,8 +216,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Cookie_Mismatch() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token mismatch", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token mismatch", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_204_Token() { @@ -244,7 +243,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_204_Token() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusNoContent, resp.Code) + s.Assert().Equal(http.StatusNoContent, resp.Code) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_400_Body_Missing_Key() { @@ -259,8 +258,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_400_Body_Missing_Key() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) - assert.Equal(s.T(), jwtMw.ErrBodyMissingKey, result.Message) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(jwtMw.ErrBodyMissingKey, result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_400_Token_Missing() { @@ -273,8 +272,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_400_Token_Missing() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusBadRequest, resp.Code) - assert.Equal(s.T(), jwtMw.ErrRequestMalformed, result.Message) + s.Assert().Equal(http.StatusBadRequest, resp.Code) + s.Assert().Equal(jwtMw.ErrRequestMalformed, result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Token_Mismatch() { @@ -300,8 +299,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Logout_401_Token_Mismatch() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token mismatch", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token mismatch", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Cookie() { @@ -335,7 +334,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Cookie() { if viper.GetBool(config.CSRFEnabled) { expected = 3 } - if assert.Equal(s.T(), expected, len(resp.Result().Cookies())) { + if s.Assert().Equal(expected, len(resp.Result().Cookies())) { cookies := 0 for _, c := range resp.Result().Cookies() { if c.Name == viper.GetString(config.JWTAccessTokenCookieName) { @@ -348,14 +347,14 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Cookie() { cookies++ } } - assert.Equal(s.T(), expected, cookies) + s.Assert().Equal(expected, cookies) } - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.NotEqual(s.T(), "", result.AccessToken) - assert.NotEqual(s.T(), "", result.ExpiresIn) - assert.NotEqual(s.T(), "", result.RefreshToken) - assert.NotEqual(s.T(), "", result.TokenType) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().NotEqual("", result.AccessToken) + s.Assert().NotEqual("", result.ExpiresIn) + s.Assert().NotEqual("", result.RefreshToken) + s.Assert().NotEqual("", result.TokenType) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_400_Cookie_Missing() { @@ -368,8 +367,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_400_Cookie_Missing() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusBadRequest, resp.Code) - assert.Equal(s.T(), jwtMw.ErrRequestMalformed, result.Message) + s.Assert().Equal(http.StatusBadRequest, resp.Code) + s.Assert().Equal(jwtMw.ErrRequestMalformed, result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Cookie_Invalid() { @@ -391,8 +390,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Cookie_Invalid() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token mismatch", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token mismatch", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Token() { @@ -426,7 +425,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Token() { if viper.GetBool(config.CSRFEnabled) { expected = 3 } - if assert.Equal(s.T(), expected, len(resp.Result().Cookies())) { + if s.Assert().Equal(expected, len(resp.Result().Cookies())) { cookies := 0 for _, c := range resp.Result().Cookies() { if c.Name == viper.GetString(config.JWTAccessTokenCookieName) { @@ -439,14 +438,14 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_200_Token() { cookies++ } } - assert.Equal(s.T(), expected, cookies) + s.Assert().Equal(expected, cookies) } - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.NotEqual(s.T(), "", result.AccessToken) - assert.NotEqual(s.T(), "", result.ExpiresIn) - assert.NotEqual(s.T(), "", result.RefreshToken) - assert.NotEqual(s.T(), "", result.TokenType) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().NotEqual("", result.AccessToken) + s.Assert().NotEqual("", result.ExpiresIn) + s.Assert().NotEqual("", result.RefreshToken) + s.Assert().NotEqual("", result.TokenType) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_400_Token_Missing() { @@ -459,8 +458,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_400_Token_Missing() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusBadRequest, resp.Code) - assert.Equal(s.T(), jwtMw.ErrRequestMalformed, result.Message) + s.Assert().Equal(http.StatusBadRequest, resp.Code) + s.Assert().Equal(jwtMw.ErrRequestMalformed, result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Token_Invalid() { @@ -479,8 +478,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Token_Invalid() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token invalid", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token invalid", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Token_Mismatch() { @@ -507,8 +506,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Refresh_401_Token_Mismatch() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token mismatch", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token mismatch", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_200() { @@ -539,7 +538,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_200() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusOK, resp.Code) + s.Assert().Equal(http.StatusOK, resp.Code) } func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_409() { @@ -567,8 +566,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_409() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusConflict, resp.Code) - assert.Equal(s.T(), services.ErrUserExist.Error(), result.Message) + s.Assert().Equal(http.StatusConflict, resp.Code) + s.Assert().Equal(services.ErrUserExist.Error(), result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_422() { @@ -581,7 +580,7 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Signup_422() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) } func (s *AuthHandlerTestSuite) TestAuthHandler_Token_200() { @@ -606,13 +605,13 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Token_200() { typ, _ := token.Get("type") - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), token.Expiration(), result.Exp) - assert.Equal(s.T(), token.IssuedAt(), result.Iat) - assert.Equal(s.T(), token.Issuer(), result.Iss) - assert.Equal(s.T(), token.NotBefore(), result.Nbf) - assert.Equal(s.T(), token.Subject(), result.Sub) - assert.Equal(s.T(), typ, result.Type) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(token.Expiration(), result.Exp) + s.Assert().Equal(token.IssuedAt(), result.Iat) + s.Assert().Equal(token.Issuer(), result.Iss) + s.Assert().Equal(token.NotBefore(), result.Nbf) + s.Assert().Equal(token.Subject(), result.Sub) + s.Assert().Equal(typ, result.Type) } func (s *AuthHandlerTestSuite) TestAuthHandler_Token_401() { @@ -625,8 +624,8 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Token_401() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token invalid", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token invalid", result.Message) } func (s *AuthHandlerTestSuite) TestAuthHandler_Cookie_200() { @@ -651,13 +650,13 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Cookie_200() { typ, _ := token.Get("type") - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), token.Expiration(), result.Exp) - assert.Equal(s.T(), token.IssuedAt(), result.Iat) - assert.Equal(s.T(), token.Issuer(), result.Iss) - assert.Equal(s.T(), token.NotBefore(), result.Nbf) - assert.Equal(s.T(), token.Subject(), result.Sub) - assert.Equal(s.T(), typ, result.Type) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(token.Expiration(), result.Exp) + s.Assert().Equal(token.IssuedAt(), result.Iat) + s.Assert().Equal(token.Issuer(), result.Iss) + s.Assert().Equal(token.NotBefore(), result.Nbf) + s.Assert().Equal(token.Subject(), result.Sub) + s.Assert().Equal(typ, result.Type) } func (s *AuthHandlerTestSuite) TestAuthHandler_Cookie_401() { @@ -671,6 +670,6 @@ func (s *AuthHandlerTestSuite) TestAuthHandler_Cookie_401() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token invalid", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token invalid", result.Message) } diff --git a/handlers/personal_access_token_test.go b/handlers/personal_access_token_test.go index 5230191..1dd7656 100644 --- a/handlers/personal_access_token_test.go +++ b/handlers/personal_access_token_test.go @@ -12,7 +12,6 @@ import ( "github.com/alexferl/echo-openapi" api "github.com/alexferl/golib/http/api/server" "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -97,7 +96,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Cre var result models.PersonalAccessToken _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) + s.Assert().Equal(http.StatusOK, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Create_401() { @@ -107,7 +106,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Cre s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Create_409() { @@ -135,7 +134,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Cre s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusConflict, resp.Code) + s.Assert().Equal(http.StatusConflict, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Create_422() { @@ -157,7 +156,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Cre s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Create_422_Exp() { @@ -185,7 +184,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Cre s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_List_200() { @@ -211,8 +210,8 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Lis var result models.PersonalAccessTokensResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), num, len(result.Tokens)) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(num, len(result.Tokens)) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_List_401() { @@ -222,7 +221,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Lis s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Get_200() { @@ -251,7 +250,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Get var result models.PersonalAccessToken _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) + s.Assert().Equal(http.StatusOK, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Get_404() { @@ -277,8 +276,8 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Get var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusNotFound, resp.Code) - assert.Equal(s.T(), services.ErrPersonalAccessTokenNotFound.Error(), result.Message) + s.Assert().Equal(http.StatusNotFound, resp.Code) + s.Assert().Equal(services.ErrPersonalAccessTokenNotFound.Error(), result.Message) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Revoke_204() { @@ -308,7 +307,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Rev s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusNoContent, resp.Code) + s.Assert().Equal(http.StatusNoContent, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Revoke_401() { @@ -318,7 +317,7 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Rev s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Revoke_404() { @@ -344,8 +343,8 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Rev var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusNotFound, resp.Code) - assert.Equal(s.T(), services.ErrPersonalAccessTokenNotFound.Error(), result.Message) + s.Assert().Equal(http.StatusNotFound, resp.Code) + s.Assert().Equal(services.ErrPersonalAccessTokenNotFound.Error(), result.Message) } func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Revoke_409() { @@ -372,5 +371,5 @@ func (s *PersonalAccessTokenHandlerTestSuite) TestPersonalAccessTokenHandler_Rev s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusConflict, resp.Code) + s.Assert().Equal(http.StatusConflict, resp.Code) } diff --git a/handlers/setup_test.go b/handlers/setup_test.go index cb323e0..cbb399b 100644 --- a/handlers/setup_test.go +++ b/handlers/setup_test.go @@ -1,11 +1,12 @@ package handlers_test import ( + api "github.com/alexferl/golib/http/api/server" + "github.com/alexferl/echo-boilerplate/handlers" "github.com/alexferl/echo-boilerplate/models" "github.com/alexferl/echo-boilerplate/server" _ "github.com/alexferl/echo-boilerplate/testing" - api "github.com/alexferl/golib/http/api/server" ) func getUser() *models.User { diff --git a/handlers/task_test.go b/handlers/task_test.go index 08e643c..4bc475a 100644 --- a/handlers/task_test.go +++ b/handlers/task_test.go @@ -11,7 +11,6 @@ import ( "github.com/alexferl/echo-openapi" api "github.com/alexferl/golib/http/api/server" "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -72,8 +71,8 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Get_200() { var result models.TaskResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), s.user.Id, result.CreatedBy.Id) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(s.user.Id, result.CreatedBy.Id) } func (s *TaskHandlerTestSuite) TestTaskHandler_401() { @@ -99,8 +98,8 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_401() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token invalid", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token invalid", result.Message) }) } } @@ -151,8 +150,8 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_404() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusNotFound, resp.Code) - assert.Equal(s.T(), services.ErrTaskNotFound.Error(), result.Message) + s.Assert().Equal(http.StatusNotFound, resp.Code) + s.Assert().Equal(services.ErrTaskNotFound.Error(), result.Message) }) } } @@ -203,8 +202,8 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_410() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusGone, resp.Code) - assert.Equal(s.T(), services.ErrTaskDeleted.Error(), result.Message) + s.Assert().Equal(http.StatusGone, resp.Code) + s.Assert().Equal(services.ErrTaskDeleted.Error(), result.Message) }) } } @@ -234,7 +233,7 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_422() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) }) } } @@ -276,9 +275,9 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Update_200() { var result models.TaskResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), title, result.Title) - assert.Equal(s.T(), s.user.Id, result.UpdatedBy.Id) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(title, result.Title) + s.Assert().Equal(s.user.Id, result.UpdatedBy.Id) } func (s *TaskHandlerTestSuite) TestTaskHandler_Update_403() { @@ -309,7 +308,7 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Update_403() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusForbidden, resp.Code) + s.Assert().Equal(http.StatusForbidden, resp.Code) } func (s *TaskHandlerTestSuite) TestTaskHandler_Transition_200() { @@ -349,9 +348,9 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Transition_200() { var result models.TaskResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), s.user.Id, result.CompletedBy.Id) - assert.True(s.T(), result.Completed) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(s.user.Id, result.CompletedBy.Id) + s.Assert().True(result.Completed) } func (s *TaskHandlerTestSuite) TestTaskHandler_Delete_200() { @@ -379,7 +378,7 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Delete_200() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusNoContent, resp.Code) + s.Assert().Equal(http.StatusNoContent, resp.Code) } func (s *TaskHandlerTestSuite) TestTaskHandler_Delete_403() { @@ -404,7 +403,7 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Delete_403() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusForbidden, resp.Code) + s.Assert().Equal(http.StatusForbidden, resp.Code) } func createTasks(num int, user *models.User) models.Tasks { @@ -449,15 +448,15 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_List_200() { `; rel=first, ` + `; rel=prev` - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), num, len(result.Tasks)) - assert.Equal(s.T(), "2", h.Get("X-Page")) - assert.Equal(s.T(), "1", h.Get("X-Per-Page")) - assert.Equal(s.T(), "10", h.Get("X-Total")) - assert.Equal(s.T(), "10", h.Get("X-Total-Pages")) - assert.Equal(s.T(), "3", h.Get("X-Next-Page")) - assert.Equal(s.T(), "1", h.Get("X-Prev-Page")) - assert.Equal(s.T(), link, h.Get("Link")) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(num, len(result.Tasks)) + s.Assert().Equal("2", h.Get("X-Page")) + s.Assert().Equal("1", h.Get("X-Per-Page")) + s.Assert().Equal("10", h.Get("X-Total")) + s.Assert().Equal("10", h.Get("X-Total-Pages")) + s.Assert().Equal("3", h.Get("X-Next-Page")) + s.Assert().Equal("1", h.Get("X-Prev-Page")) + s.Assert().Equal(link, h.Get("Link")) } func (s *TaskHandlerTestSuite) TestTaskHandler_Create_200() { @@ -483,7 +482,7 @@ func (s *TaskHandlerTestSuite) TestTaskHandler_Create_200() { var result models.TaskResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), payload.Title, result.Title) - assert.Equal(s.T(), s.user.Id, result.CreatedBy.Id) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(payload.Title, result.Title) + s.Assert().Equal(s.user.Id, result.CreatedBy.Id) } diff --git a/handlers/user_test.go b/handlers/user_test.go index ac69723..d774d45 100644 --- a/handlers/user_test.go +++ b/handlers/user_test.go @@ -11,7 +11,6 @@ import ( "github.com/alexferl/echo-openapi" api "github.com/alexferl/golib/http/api/server" "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -77,8 +76,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_GetCurrentUser_200() { var result models.UserResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), s.user.Id, result.Id) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(s.user.Id, result.Id) } func (s *UserHandlerTestSuite) TestUserHandler_UpdateCurrentUser_200() { @@ -111,8 +110,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_UpdateCurrentUser_200() { var result models.UserResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), updatedUser.Name, result.Name) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(updatedUser.Name, result.Name) } func (s *UserHandlerTestSuite) TestUserHandler_UpdateCurrentUser_422() { @@ -132,7 +131,7 @@ func (s *UserHandlerTestSuite) TestUserHandler_UpdateCurrentUser_422() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) } func (s *UserHandlerTestSuite) TestUserHandler_Get_200() { @@ -155,8 +154,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_Get_200() { var result models.UserResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), s.user.Id, result.Id) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(s.user.Id, result.Id) } func (s *UserHandlerTestSuite) TestUserHandler_Update_200() { @@ -189,8 +188,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_Update_200() { var result models.UserResponse _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), updatedUser.Name, result.Name) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(updatedUser.Name, result.Name) } func (s *UserHandlerTestSuite) TestUserHandler_Update_404() { @@ -222,8 +221,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_Update_404() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusNotFound, resp.Code) - assert.Equal(s.T(), services.ErrUserNotFound.Error(), result.Message) + s.Assert().Equal(http.StatusNotFound, resp.Code) + s.Assert().Equal(services.ErrUserNotFound.Error(), result.Message) } func (s *UserHandlerTestSuite) TestUserHandler_Update_410() { @@ -256,8 +255,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_Update_410() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusGone, resp.Code) - assert.Equal(s.T(), services.ErrUserDeleted.Error(), result.Message) + s.Assert().Equal(http.StatusGone, resp.Code) + s.Assert().Equal(services.ErrUserDeleted.Error(), result.Message) } func (s *UserHandlerTestSuite) TestUserHandler_204() { @@ -301,7 +300,7 @@ func (s *UserHandlerTestSuite) TestUserHandler_204() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusNoContent, resp.Code) + s.Assert().Equal(http.StatusNoContent, resp.Code) }) } } @@ -334,8 +333,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_401() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnauthorized, resp.Code) - assert.Equal(s.T(), "token invalid", result.Message) + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal("token invalid", result.Message) }) } } @@ -373,7 +372,7 @@ func (s *UserHandlerTestSuite) TestUserHandler_403() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusForbidden, resp.Code) + s.Assert().Equal(http.StatusForbidden, resp.Code) }) } } @@ -415,8 +414,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_404() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusNotFound, resp.Code) - assert.Equal(s.T(), services.ErrUserNotFound.Error(), result.Message) + s.Assert().Equal(http.StatusNotFound, resp.Code) + s.Assert().Equal(services.ErrUserNotFound.Error(), result.Message) }) } } @@ -454,7 +453,7 @@ func (s *UserHandlerTestSuite) TestUserHandler_409() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusConflict, resp.Code) + s.Assert().Equal(http.StatusConflict, resp.Code) }) } } @@ -496,8 +495,8 @@ func (s *UserHandlerTestSuite) TestUserHandler_410() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusGone, resp.Code) - assert.Equal(s.T(), services.ErrUserDeleted.Error(), result.Message) + s.Assert().Equal(http.StatusGone, resp.Code) + s.Assert().Equal(services.ErrUserDeleted.Error(), result.Message) }) } } @@ -527,7 +526,7 @@ func (s *UserHandlerTestSuite) TestUserHandler_Roles_422() { var result echo.HTTPError _ = json.Unmarshal(resp.Body.Bytes(), &result) - assert.Equal(s.T(), http.StatusUnprocessableEntity, resp.Code) + s.Assert().Equal(http.StatusUnprocessableEntity, resp.Code) }) } } @@ -572,15 +571,15 @@ func (s *UserHandlerTestSuite) TestUserHandler_List_200() { `; rel=first, ` + `; rel=prev` - assert.Equal(s.T(), http.StatusOK, resp.Code) - assert.Equal(s.T(), 10, len(result.Users)) - assert.Equal(s.T(), "2", h.Get("X-Page")) - assert.Equal(s.T(), "1", h.Get("X-Per-Page")) - assert.Equal(s.T(), "10", h.Get("X-Total")) - assert.Equal(s.T(), "10", h.Get("X-Total-Pages")) - assert.Equal(s.T(), "3", h.Get("X-Next-Page")) - assert.Equal(s.T(), "1", h.Get("X-Prev-Page")) - assert.Equal(s.T(), link, h.Get("Link")) + s.Assert().Equal(http.StatusOK, resp.Code) + s.Assert().Equal(10, len(result.Users)) + s.Assert().Equal("2", h.Get("X-Page")) + s.Assert().Equal("1", h.Get("X-Per-Page")) + s.Assert().Equal("10", h.Get("X-Total")) + s.Assert().Equal("10", h.Get("X-Total-Pages")) + s.Assert().Equal("3", h.Get("X-Next-Page")) + s.Assert().Equal("1", h.Get("X-Prev-Page")) + s.Assert().Equal(link, h.Get("Link")) } func (s *UserHandlerTestSuite) TestUserHandler_List_403() { @@ -596,5 +595,5 @@ func (s *UserHandlerTestSuite) TestUserHandler_List_403() { s.server.ServeHTTP(resp, req) - assert.Equal(s.T(), http.StatusForbidden, resp.Code) + s.Assert().Equal(http.StatusForbidden, resp.Code) } diff --git a/server/server.go b/server/server.go index cb7a5d0..9f64f2f 100644 --- a/server/server.go +++ b/server/server.go @@ -26,6 +26,17 @@ import ( "github.com/alexferl/echo-boilerplate/util/jwt" ) +var ( + ErrBanned = errors.New("account banned") + ErrLocked = errors.New("account locked") + ErrCookieMissing = errors.New("missing access token cookie") + ErrCSRFHeaderMissing = errors.New("missing CSRF token header") + ErrCSRFInvalid = errors.New("invalid CSRF token") + ErrTokenInvalid = errors.New("token invalid") + ErrTokenMismatch = errors.New("token mismatch") + ErrTokenRevoked = errors.New("token is revoked") +) + func New() *server.Server { client, err := data.MewMongoClient() if err != nil { @@ -90,7 +101,7 @@ func newServer(userSvc handlers.UserService, patSvc handlers.PersonalAccessToken user, err := userSvc.Read(ctx, t.Subject()) if err != nil { log.Error().Err(err).Msg("failed getting user") - return echo.NewHTTPError(http.StatusInternalServerError, "Internal Server Error") + return echo.NewHTTPError(http.StatusServiceUnavailable) } c.Set("user", user) @@ -98,10 +109,10 @@ func newServer(userSvc handlers.UserService, patSvc handlers.PersonalAccessToken c.Set("roles", user.Roles) if user.IsBanned { - return echo.NewHTTPError(http.StatusForbidden, "account banned") + return echo.NewHTTPError(http.StatusForbidden, ErrBanned.Error()) } if user.IsLocked { - return echo.NewHTTPError(http.StatusForbidden, "account locked") + return echo.NewHTTPError(http.StatusForbidden, ErrLocked.Error()) } // CSRF @@ -112,21 +123,20 @@ func newServer(userSvc handlers.UserService, patSvc handlers.PersonalAccessToken default: // Validate token only for requests which are not defined as 'safe' by RFC7231 cookie, err := c.Cookie(viper.GetString(config.JWTAccessTokenCookieName)) if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "missing access token cookie") + return echo.NewHTTPError(http.StatusBadRequest, ErrCookieMissing) } h := c.Request().Header.Get(viper.GetString(config.CSRFHeaderName)) if h == "" { - return echo.NewHTTPError(http.StatusBadRequest, "missing CSRF token header") + return echo.NewHTTPError(http.StatusBadRequest, ErrCSRFHeaderMissing) } if !hash.ValidMAC([]byte(cookie.Value), []byte(h), []byte(viper.GetString(config.CSRFSecretKey))) { - return echo.NewHTTPError(http.StatusForbidden, "invalid CSRF token") + return echo.NewHTTPError(http.StatusForbidden, ErrCSRFInvalid) } } } } - // Personal Access Tokens claims := t.PrivateClaims() typ := claims["type"] @@ -136,18 +146,18 @@ func newServer(userSvc handlers.UserService, patSvc handlers.PersonalAccessToken var se *services.Error if errors.As(err, &se) { if se.Kind == services.NotExist { - return echo.NewHTTPError(http.StatusUnauthorized, "token invalid") + return echo.NewHTTPError(http.StatusUnauthorized, ErrTokenInvalid) } } - return echo.NewHTTPError(http.StatusInternalServerError, "Internal Server Error") + return echo.NewHTTPError(http.StatusServiceUnavailable) } if err = pat.Validate(encodedToken); err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "token mismatch") + return echo.NewHTTPError(http.StatusUnauthorized, ErrTokenMismatch) } if pat.IsRevoked { - return echo.NewHTTPError(http.StatusUnauthorized, "token is revoked") + return echo.NewHTTPError(http.StatusUnauthorized, ErrTokenRevoked) } } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..b8a9e3a --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,300 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alexferl/echo-openapi" + api "github.com/alexferl/golib/http/api/server" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/alexferl/echo-boilerplate/handlers" + "github.com/alexferl/echo-boilerplate/models" + "github.com/alexferl/echo-boilerplate/services" + _ "github.com/alexferl/echo-boilerplate/testing" + "github.com/alexferl/echo-boilerplate/util/cookie" +) + +type ServerTestSuite struct { + suite.Suite + svc *handlers.MockUserService + patSvc *handlers.MockPersonalAccessTokenService + server *api.Server + user *models.User + accessToken []byte + admin *models.User +} + +func (s *ServerTestSuite) SetupTest() { + svc := handlers.NewMockUserService(s.T()) + patSvc := handlers.NewMockPersonalAccessTokenService(s.T()) + h := handlers.NewUserHandler(openapi.NewHandler(), svc) + + admin := models.NewUserWithRole("test@example.com", "test", models.AdminRole) + user := models.NewUser("test@example.com", "test") + user.Id = "1000" + user.Create(user.Id) + access, _, _ := user.Login() + + s.svc = svc + s.patSvc = patSvc + s.server = NewTestServer(svc, patSvc, h) + s.user = user + s.accessToken = access + s.admin = admin +} + +func TestServerTestSuite(t *testing.T) { + suite.Run(t, new(ServerTestSuite)) +} + +func (s *ServerTestSuite) TestServer_503() { + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.accessToken)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(nil, errors.New("")).Once() + + s.server.ServeHTTP(resp, req) + + s.Assert().Equal(http.StatusServiceUnavailable, resp.Code) +} + +func (s *ServerTestSuite) TestServer_403_Banned() { + _ = s.user.Ban(s.admin) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.accessToken)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusForbidden, resp.Code) + s.Assert().Equal(ErrBanned.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_403_Locked() { + _ = s.user.Lock(s.admin) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.accessToken)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusForbidden, resp.Code) + s.Assert().Equal(ErrLocked.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_400_CSRF_Header_Missing() { + req := httptest.NewRequest(http.MethodPatch, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(cookie.NewAccessToken(s.accessToken)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusBadRequest, resp.Code) + s.Assert().Equal(ErrCSRFHeaderMissing.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_400_CSRF_Header_Invalid() { + req := httptest.NewRequest(http.MethodPatch, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(cookie.NewAccessToken(s.accessToken)) + req.Header.Add("X-CSRF-Token", "token") + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusForbidden, resp.Code) + s.Assert().Equal(ErrCSRFInvalid.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_PAT_401_Token_Invalid() { + pat, _ := models.NewPersonalAccessToken( + s.user.Id, + fmt.Sprintf("my_token"), + time.Now().Add((7*24)*time.Hour).Format("2006-01-02"), + ) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", pat.Token)) + resp := httptest.NewRecorder() + + _ = pat.Encrypt() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.patSvc.EXPECT(). + FindOne(mock.Anything, mock.Anything, mock.Anything). + Return(nil, services.NewError(nil, services.NotExist, "")).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal(ErrTokenInvalid.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_PAT_401_Token_Mismatch() { + pat, _ := models.NewPersonalAccessToken( + s.user.Id, + fmt.Sprintf("my_token"), + time.Now().Add((7*24)*time.Hour).Format("2006-01-02"), + ) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", pat.Token)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.patSvc.EXPECT(). + FindOne(mock.Anything, mock.Anything, mock.Anything). + Return(pat, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal(ErrTokenMismatch.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_PAT_401_Revoked() { + pat, _ := models.NewPersonalAccessToken( + s.user.Id, + fmt.Sprintf("my_token"), + time.Now().Add((7*24)*time.Hour).Format("2006-01-02"), + ) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", pat.Token)) + resp := httptest.NewRecorder() + + _ = pat.Encrypt() + pat.IsRevoked = true + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.patSvc.EXPECT(). + FindOne(mock.Anything, mock.Anything, mock.Anything). + Return(pat, nil).Once() + + s.server.ServeHTTP(resp, req) + + var result echo.HTTPError + _ = json.Unmarshal(resp.Body.Bytes(), &result) + + s.Assert().Equal(http.StatusUnauthorized, resp.Code) + s.Assert().Equal(ErrTokenRevoked.Error(), result.Message) +} + +func (s *ServerTestSuite) TestServer_PAT_503() { + pat, _ := models.NewPersonalAccessToken( + s.user.Id, + fmt.Sprintf("my_token"), + time.Now().Add((7*24)*time.Hour).Format("2006-01-02"), + ) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", pat.Token)) + resp := httptest.NewRecorder() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.patSvc.EXPECT(). + FindOne(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("")).Once() + + s.server.ServeHTTP(resp, req) + + s.Assert().Equal(http.StatusServiceUnavailable, resp.Code) +} + +func (s *ServerTestSuite) TestServer_PAT_200() { + pat, _ := models.NewPersonalAccessToken( + s.user.Id, + fmt.Sprintf("my_token"), + time.Now().Add((7*24)*time.Hour).Format("2006-01-02"), + ) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", pat.Token)) + resp := httptest.NewRecorder() + + _ = pat.Encrypt() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.patSvc.EXPECT(). + FindOne(mock.Anything, mock.Anything, mock.Anything). + Return(pat, nil).Once() + + s.svc.EXPECT(). + Read(mock.Anything, mock.Anything). + Return(s.user, nil).Once() + + s.server.ServeHTTP(resp, req) + + s.Assert().Equal(http.StatusOK, resp.Code) +} diff --git a/services/personal_access_token_test.go b/services/personal_access_token_test.go index a9551be..73db6e9 100644 --- a/services/personal_access_token_test.go +++ b/services/personal_access_token_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -38,23 +37,23 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessToken_Create() { name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) s.mapper.EXPECT(). Create(mock.Anything, mock.Anything). Return(m, nil) pat, err := s.svc.Create(context.Background(), m) - assert.NoError(s.T(), err) - assert.Equal(s.T(), name, pat.Name) - assert.Equal(s.T(), expiresAt, pat.ExpiresAt.Format("2006-01-02")) + s.Assert().NoError(err) + s.Assert().Equal(name, pat.Name) + s.Assert().Equal(expiresAt, pat.ExpiresAt.Format("2006-01-02")) } func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Read() { name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) id := "123" m.Id = id @@ -63,17 +62,17 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Read() { Return(m, nil) pat, err := s.svc.Read(context.Background(), id) - assert.NoError(s.T(), err) - assert.Equal(s.T(), id, pat.Id) - assert.Equal(s.T(), name, pat.Name) - assert.Equal(s.T(), expiresAt, pat.ExpiresAt.Format("2006-01-02")) + s.Assert().NoError(err) + s.Assert().Equal(id, pat.Id) + s.Assert().Equal(name, pat.Name) + s.Assert().Equal(expiresAt, pat.ExpiresAt.Format("2006-01-02")) } func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Read_Err() { name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) id := "123" m.Id = id @@ -82,11 +81,11 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Read_Err Return(nil, data.ErrNoDocuments) _, err = s.svc.Read(context.Background(), id) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.NotExist, se.Kind) + s.Assert().Equal(services.NotExist, se.Kind) } } @@ -94,7 +93,7 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Revoke() name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) id := "123" m.Id = id @@ -103,18 +102,18 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Revoke() Return(m, nil) err = s.svc.Revoke(context.Background(), m) - assert.NoError(s.T(), err) + s.Assert().NoError(err) s.mapper.EXPECT(). FindOne(mock.Anything, mock.Anything). Return(m, nil) pat, err := s.svc.Read(context.Background(), id) - assert.NoError(s.T(), err) - assert.True(s.T(), pat.IsRevoked) - assert.Equal(s.T(), id, pat.Id) - assert.Equal(s.T(), name, pat.Name) - assert.Equal(s.T(), expiresAt, pat.ExpiresAt.Format("2006-01-02")) + s.Assert().NoError(err) + s.Assert().True(pat.IsRevoked) + s.Assert().Equal(id, pat.Id) + s.Assert().Equal(name, pat.Name) + s.Assert().Equal(expiresAt, pat.ExpiresAt.Format("2006-01-02")) } func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Find() { @@ -123,15 +122,15 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_Find() { Return(models.PersonalAccessTokens{}, nil) pats, err := s.svc.Find(context.Background(), "123") - assert.NoError(s.T(), err) - assert.Equal(s.T(), models.PersonalAccessTokens{}, pats) + s.Assert().NoError(err) + s.Assert().Equal(models.PersonalAccessTokens{}, pats) } func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_FindOne() { name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) id := "123" userId := "456" m.Id = id @@ -142,18 +141,18 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_FindOne( Return(m, nil) pat, err := s.svc.FindOne(context.Background(), userId, name) - assert.NoError(s.T(), err) - assert.Equal(s.T(), id, pat.Id) - assert.Equal(s.T(), userId, pat.UserId) - assert.Equal(s.T(), name, pat.Name) - assert.Equal(s.T(), expiresAt, pat.ExpiresAt.Format("2006-01-02")) + s.Assert().NoError(err) + s.Assert().Equal(id, pat.Id) + s.Assert().Equal(userId, pat.UserId) + s.Assert().Equal(name, pat.Name) + s.Assert().Equal(expiresAt, pat.ExpiresAt.Format("2006-01-02")) } func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_FindOne_Err() { name := "my_token" expiresAt := time.Now().Add((7 * 24) * time.Hour).Format("2006-01-02") m, err := models.NewPersonalAccessToken(s.user.Id, name, expiresAt) - assert.NoError(s.T(), err) + s.Assert().NoError(err) id := "123" userId := "456" m.Id = id @@ -164,10 +163,10 @@ func (s *PersonalAccessTokenTestSuite) TestPersonalAccessTokenTestSuite_FindOne_ Return(nil, data.ErrNoDocuments) _, err = s.svc.FindOne(context.Background(), id, "") - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.NotExist, se.Kind) + s.Assert().Equal(services.NotExist, se.Kind) } } diff --git a/services/task_test.go b/services/task_test.go index 5328e86..37c0143 100644 --- a/services/task_test.go +++ b/services/task_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -39,8 +38,8 @@ func (s *TaskTestSuite) TestTask_Create() { Return(m, nil) task, err := s.svc.Create(context.Background(), id, m) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), task.CreatedBy) + s.Assert().NoError(err) + s.Assert().NotNil(task.CreatedBy) } func (s *TaskTestSuite) TestTask_Read() { @@ -53,8 +52,8 @@ func (s *TaskTestSuite) TestTask_Read() { Return(m, nil) task, err := s.svc.Read(context.Background(), id) - assert.NoError(s.T(), err) - assert.Equal(s.T(), id, task.Id) + s.Assert().NoError(err) + s.Assert().Equal(id, task.Id) } func (s *TaskTestSuite) TestTask_Read_Err() { @@ -67,11 +66,11 @@ func (s *TaskTestSuite) TestTask_Read_Err() { Return(nil, data.ErrNoDocuments) _, err := s.svc.Read(context.Background(), id) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.NotExist, se.Kind) + s.Assert().Equal(services.NotExist, se.Kind) } } @@ -85,8 +84,8 @@ func (s *TaskTestSuite) TestTask_Update() { Return(m, nil) task, err := s.svc.Update(context.Background(), id, m) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), task.UpdatedBy) + s.Assert().NoError(err) + s.Assert().NotNil(task.UpdatedBy) } func (s *TaskTestSuite) TestTask_Delete() { @@ -99,18 +98,18 @@ func (s *TaskTestSuite) TestTask_Delete() { Return(m, nil) err := s.svc.Delete(context.Background(), id, m) - assert.NoError(s.T(), err) + s.Assert().NoError(err) s.mapper.EXPECT(). FindOneById(mock.Anything, mock.Anything). Return(m, nil) _, err = s.svc.Read(context.Background(), id) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.Deleted, se.Kind) + s.Assert().Equal(services.Deleted, se.Kind) } } @@ -126,7 +125,7 @@ func (s *TaskTestSuite) TestTask_Find() { Limit: 1, Skip: 0, }) - assert.NoError(s.T(), err) - assert.Equal(s.T(), int64(1), count) - assert.Equal(s.T(), models.Tasks{}, tasks) + s.Assert().NoError(err) + s.Assert().Equal(int64(1), count) + s.Assert().Equal(models.Tasks{}, tasks) } diff --git a/services/user_test.go b/services/user_test.go index 7d39f75..0f2de86 100644 --- a/services/user_test.go +++ b/services/user_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "go.mongodb.org/mongo-driver/mongo" @@ -42,10 +41,10 @@ func (s *UserTestSuite) TestUser_Create() { Return(m, nil) user, err := s.svc.Create(context.Background(), m) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), user.CreatedBy) - assert.Equal(s.T(), email, user.Email) - assert.Equal(s.T(), username, user.Username) + s.Assert().NoError(err) + s.Assert().NotNil(user.CreatedBy) + s.Assert().Equal(email, user.Email) + s.Assert().Equal(username, user.Username) } func (s *UserTestSuite) TestUser_Create_Err() { @@ -60,11 +59,11 @@ func (s *UserTestSuite) TestUser_Create_Err() { Return(nil, &mongo.WriteError{Code: 11000}) _, err := s.svc.Create(context.Background(), m) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.Exist, se.Kind) + s.Assert().Equal(services.Exist, se.Kind) } } @@ -80,10 +79,10 @@ func (s *UserTestSuite) TestUser_Read() { Return(m, nil) user, err := s.svc.Read(context.Background(), id) - assert.NoError(s.T(), err) - assert.Equal(s.T(), id, user.Id) - assert.Equal(s.T(), email, user.Email) - assert.Equal(s.T(), username, user.Username) + s.Assert().NoError(err) + s.Assert().Equal(id, user.Id) + s.Assert().Equal(email, user.Email) + s.Assert().Equal(username, user.Username) } func (s *UserTestSuite) TestUser_Read_Err() { @@ -98,11 +97,11 @@ func (s *UserTestSuite) TestUser_Read_Err() { Return(nil, data.ErrNoDocuments) _, err := s.svc.Read(context.Background(), id) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.NotExist, se.Kind) + s.Assert().Equal(services.NotExist, se.Kind) } } @@ -118,8 +117,8 @@ func (s *UserTestSuite) TestUser_Update() { Return(m, nil) task, err := s.svc.Update(context.Background(), id, m) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), task.UpdatedBy) + s.Assert().NoError(err) + s.Assert().NotNil(task.UpdatedBy) } func (s *UserTestSuite) TestUser_Delete() { @@ -134,18 +133,18 @@ func (s *UserTestSuite) TestUser_Delete() { Return(m, nil) err := s.svc.Delete(context.Background(), id, m) - assert.NoError(s.T(), err) + s.Assert().NoError(err) s.mapper.EXPECT(). FindOne(mock.Anything, mock.Anything). Return(m, nil) _, err = s.svc.Read(context.Background(), id) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.Deleted, se.Kind) + s.Assert().Equal(services.Deleted, se.Kind) } } @@ -158,9 +157,9 @@ func (s *UserTestSuite) TestUser_Find() { Limit: 1, Skip: 0, }) - assert.NoError(s.T(), err) - assert.Equal(s.T(), int64(1), count) - assert.Equal(s.T(), models.Users{}, tasks) + s.Assert().NoError(err) + s.Assert().Equal(int64(1), count) + s.Assert().Equal(models.Users{}, tasks) } func (s *UserTestSuite) TestUser_FindOneByEmailOrUsername() { @@ -175,10 +174,10 @@ func (s *UserTestSuite) TestUser_FindOneByEmailOrUsername() { Return(m, nil) user, err := s.svc.FindOneByEmailOrUsername(context.Background(), email, username) - assert.NoError(s.T(), err) - assert.Equal(s.T(), id, user.Id) - assert.Equal(s.T(), email, user.Email) - assert.Equal(s.T(), username, user.Username) + s.Assert().NoError(err) + s.Assert().Equal(id, user.Id) + s.Assert().Equal(email, user.Email) + s.Assert().Equal(username, user.Username) } func (s *UserTestSuite) TestUser_FindOneByEmailOrUsername_Err() { @@ -193,10 +192,10 @@ func (s *UserTestSuite) TestUser_FindOneByEmailOrUsername_Err() { Return(nil, data.ErrNoDocuments) _, err := s.svc.FindOneByEmailOrUsername(context.Background(), email, username) - assert.Error(s.T(), err) + s.Assert().Error(err) var se *services.Error - assert.ErrorAs(s.T(), err, &se) + s.Assert().ErrorAs(err, &se) if errors.As(err, &se) { - assert.Equal(s.T(), services.NotExist, se.Kind) + s.Assert().Equal(services.NotExist, se.Kind) } }