Skip to content

Commit

Permalink
fix: update the wire configuration and move middleware from the handl…
Browse files Browse the repository at this point in the history
…er to the controller
  • Loading branch information
Eraxyso authored Dec 19, 2024
1 parent d33f4f9 commit 362ed96
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 102 deletions.
88 changes: 49 additions & 39 deletions handler/middleware.go → controller/middleware.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package handler
package controller

import (
"errors"
Expand All @@ -21,8 +21,18 @@ type Middleware struct {
}

// NewMiddleware Middlewareのコンストラクタ
func NewMiddleware() *Middleware {
return &Middleware{}
func NewMiddleware(
administrator model.IAdministrator,
respondent model.IRespondent,
question model.IQuestion,
questionnaire model.IQuestionnaire,
) *Middleware {
return &Middleware{
IAdministrator: administrator,
IRespondent: respondent,
IQuestion: question,
IQuestionnaire: questionnaire,
}
}

const (
Expand All @@ -41,7 +51,7 @@ const (
var adminUserIDs = []string{"ryoha", "xxarupakaxx", "kaitoyama", "cp20", "itzmeowww"}

// SetUserIDMiddleware X-Showcase-UserからユーザーIDを取得しセットする
func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
userID := c.Request().Header.Get("X-Showcase-User")
if userID == "" {
Expand All @@ -55,9 +65,9 @@ func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
}

// TraPMemberAuthenticate traP部員かの認証
func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -74,11 +84,11 @@ func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFun
}

// TrapRateLimitMiddlewareFunc traP IDベースのリクエスト制限
func (*Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc {
func (m Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc {
config := middleware.RateLimiterConfig{
Store: middleware.NewRateLimiterMemoryStore(5),
IdentifierExtractor: func(c echo.Context) (string, error) {
userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -92,10 +102,10 @@ func (*Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc {
}

// QuestionnaireReadAuthenticate アンケートの閲覧権限があるかの認証
func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {

userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -116,7 +126,7 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H
return next(c)
}
}
isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID)
isAdmin, err := m.IAdministrator.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID)
if err != nil {
c.Logger().Errorf("failed to check questionnaire admin: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err))
Expand All @@ -127,7 +137,7 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H
}

// 公開されたらOK
questionnaire, _, _, _, _, _, err := m.GetQuestionnaireInfo(c.Request().Context(), questionnaireID)
questionnaire, _, _, _, _, _, err := m.IQuestionnaire.GetQuestionnaireInfo(c.Request().Context(), questionnaireID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("questionnaire not found: %+v", err)
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("questionnaire not found:%d", questionnaireID))
Expand All @@ -147,10 +157,10 @@ func (m *Middleware) QuestionnaireReadAuthenticate(next echo.HandlerFunc) echo.H
}

// QuestionnaireAdministratorAuthenticate アンケートの管理者かどうかの認証
func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {

userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -170,7 +180,7 @@ func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFun
return next(c)
}
}
isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID)
isAdmin, err := m.IAdministrator.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID)
if err != nil {
c.Logger().Errorf("failed to check questionnaire admin: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err))
Expand All @@ -186,10 +196,10 @@ func (m *Middleware) QuestionnaireAdministratorAuthenticate(next echo.HandlerFun
}

// ResponseReadAuthenticate 回答閲覧権限があるかの認証
func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {

userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -203,7 +213,7 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle
}

// 回答者ならOK
respondent, err := m.GetRespondent(c.Request().Context(), responseID)
respondent, err := m.IRespondent.GetRespondent(c.Request().Context(), responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID))
Expand All @@ -229,7 +239,7 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle
}

// アンケートごとの回答閲覧権限チェック
responseReadPrivilegeInfo, err := m.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID)
responseReadPrivilegeInfo, err := m.IQuestionnaire.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid responseID: %d", responseID))
Expand All @@ -252,10 +262,10 @@ func (m *Middleware) ResponseReadAuthenticate(next echo.HandlerFunc) echo.Handle
}

// RespondentAuthenticate 回答者かどうかの認証
func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
func (m Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {

userID, err := getUserID(c)
userID, err := m.GetUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -268,7 +278,7 @@ func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerF
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err))
}

respondent, err := m.GetRespondent(c.Request().Context(), responseID)
respondent, err := m.IRespondent.GetRespondent(c.Request().Context(), responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID))
Expand All @@ -291,21 +301,8 @@ func (m *Middleware) RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerF
}
}

func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) {
switch responseReadPrivilegeInfo.ResSharedTo {
case "administrators":
return responseReadPrivilegeInfo.IsAdministrator, nil
case "respondents":
return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil
case "public":
return true, nil
}

return false, errors.New("invalid resSharedTo")
}

// getValidator Validatorを設定する
func getValidator(c echo.Context) (*validator.Validate, error) {
// GetValidator Validatorを設定する
func (m Middleware) GetValidator(c echo.Context) (*validator.Validate, error) {
rowValidate := c.Get(validatorKey)
validate, ok := rowValidate.(*validator.Validate)
if !ok {
Expand All @@ -315,8 +312,8 @@ func getValidator(c echo.Context) (*validator.Validate, error) {
return validate, nil
}

// getUserID ユーザーIDを取得する
func getUserID(c echo.Context) (string, error) {
// GetUserID ユーザーIDを取得する
func (m Middleware) GetUserID(c echo.Context) (string, error) {
rowUserID := c.Get(userIDKey)
userID, ok := rowUserID.(string)
if !ok {
Expand All @@ -325,3 +322,16 @@ func getUserID(c echo.Context) (string, error) {

return userID, nil
}

func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) {
switch responseReadPrivilegeInfo.ResSharedTo {
case "administrators":
return responseReadPrivilegeInfo.IsAdministrator, nil
case "respondents":
return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil
case "public":
return true, nil
}

return false, errors.New("invalid resSharedTo")
}
3 changes: 3 additions & 0 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import "github.com/traPtitech/anke-to/controller"
type Handler struct {
Questionnaire *controller.Questionnaire
Response *controller.Response
Middleware *controller.Middleware
}

func NewHandler(questionnaire *controller.Questionnaire,
response *controller.Response,
middleware *controller.Middleware,
) *Handler {
return &Handler{
Questionnaire: questionnaire,
Response: response,
Middleware: middleware,
}
}
12 changes: 6 additions & 6 deletions handler/questionnaire.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// (GET /questionnaires)
func (h Handler) GetQuestionnaires(ctx echo.Context, params openapi.GetQuestionnairesParams) error {
res := openapi.QuestionnaireList{}
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -35,7 +35,7 @@ func (h Handler) PostQuestionnaire(ctx echo.Context) error {
ctx.Logger().Errorf("failed to bind request body: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind request body: %w", err))
}
validate, err := getValidator(ctx)
validate, err := h.Middleware.GetValidator(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get validator: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err))
Expand All @@ -48,7 +48,7 @@ func (h Handler) PostQuestionnaire(ctx echo.Context) error {
}

res := openapi.QuestionnaireDetail{}
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand Down Expand Up @@ -136,7 +136,7 @@ func (h Handler) EditQuestionnaireMyRemindStatus(ctx echo.Context, questionnaire

// (GET /questionnaires/{questionnaireID}/responses)
func (h Handler) GetQuestionnaireResponses(ctx echo.Context, questionnaireID openapi.QuestionnaireIDInPath, params openapi.GetQuestionnaireResponsesParams) error {
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -153,7 +153,7 @@ func (h Handler) GetQuestionnaireResponses(ctx echo.Context, questionnaireID ope
// (POST /questionnaires/{questionnaireID}/responses)
func (h Handler) PostQuestionnaireResponse(ctx echo.Context, questionnaireID openapi.QuestionnaireIDInPath) error {
res := openapi.Response{}
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -164,7 +164,7 @@ func (h Handler) PostQuestionnaireResponse(ctx echo.Context, questionnaireID ope
ctx.Logger().Errorf("failed to bind request body: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind request body: %w", err))
}
validate, err := getValidator(ctx)
validate, err := h.Middleware.GetValidator(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get validator: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err))
Expand Down
6 changes: 3 additions & 3 deletions handler/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// (GET /responses/myResponses)
func (h Handler) GetMyResponses(ctx echo.Context, params openapi.GetMyResponsesParams) error {
res := openapi.ResponsesWithQuestionnaireInfo{}
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand All @@ -27,7 +27,7 @@ func (h Handler) GetMyResponses(ctx echo.Context, params openapi.GetMyResponsesP

// (DELETE /responses/{responseID})
func (h Handler) DeleteResponse(ctx echo.Context, responseID openapi.ResponseIDInPath) error {
userID, err := getUserID(ctx)
userID, err := h.Middleware.GetUserID(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
Expand Down Expand Up @@ -62,7 +62,7 @@ func (h Handler) EditResponse(ctx echo.Context, responseID openapi.ResponseIDInP
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind Responses: %w", err))
}

validate, err := getValidator(ctx)
validate, err := h.Middleware.GetValidator(ctx)
if err != nil {
ctx.Logger().Errorf("failed to get validator: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get validator: %w", err))
Expand Down
21 changes: 10 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,23 @@ func main() {
}
api := InjectAPIServer()
e.Use(oapiMiddleware.OapiRequestValidator(swagger))
e.Use(api.SetUserIDMiddleware)
e.Use(api.Middleware.SetUserIDMiddleware)
e.Use(middleware.Logger())
e.Use(middleware.Recover())

mws := NewMiddlewareSwitcher()
mws.AddGroupConfig("", api.TraPMemberAuthenticate)
mws.AddGroupConfig("", api.Middleware.TraPMemberAuthenticate)

mws.AddRouteConfig("/questionnaires", http.MethodGet, api.TrapRateLimitMiddlewareFunc())
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, api.QuestionnaireReadAuthenticate)
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, api.QuestionnaireAdministratorAuthenticate)
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, api.QuestionnaireAdministratorAuthenticate)
mws.AddRouteConfig("/questionnaires", http.MethodGet, api.Middleware.TrapRateLimitMiddlewareFunc())
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, api.Middleware.QuestionnaireReadAuthenticate)
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, api.Middleware.QuestionnaireAdministratorAuthenticate)
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, api.Middleware.QuestionnaireAdministratorAuthenticate)

mws.AddRouteConfig("/responses/:responseID", http.MethodGet, api.ResponseReadAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, api.RespondentAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, api.RespondentAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodGet, api.Middleware.ResponseReadAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, api.Middleware.RespondentAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, api.Middleware.RespondentAuthenticate)

handlerApi := InjectHandler()
openapi.RegisterHandlers(e, handlerApi)
openapi.RegisterHandlers(e, api)

e.Use(mws.ApplyMiddlewares)
e.Logger.Fatal(e.Start(port))
Expand Down
Loading

0 comments on commit 362ed96

Please sign in to comment.