-
Notifications
You must be signed in to change notification settings - Fork 28
/
user_authenticate.go
81 lines (67 loc) · 2.24 KB
/
user_authenticate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package middlewares
import (
	"context"
	"errors"
	"strings"

	"github.com/gofrs/uuid"
	"github.com/labstack/echo/v4"

	"github.com/traPtitech/traQ/repository"
	"github.com/traPtitech/traQ/router/consts"
	"github.com/traPtitech/traQ/router/extension/ctxkey"
	"github.com/traPtitech/traQ/router/extension/herror"
	"github.com/traPtitech/traQ/router/session"
)
// authScheme is the HTTP authentication scheme expected in the Authorization
// header when a request is authenticated with an OAuth2 access token.
const authScheme = "Bearer"

// UserAuthenticate is a request-authentication middleware.
//
// If the request carries an Authorization header, it must have the form
// "<authScheme> <access token>". The token is resolved with
// repo.GetTokenByAccess and rejected if it is unknown or expired. Otherwise
// the session from sessStore must exist and be logged in.
//
// The resolved user is loaded with repo.GetUser and must be active. On
// success the user is stored under consts.KeyUser, and its ID is stored under
// consts.KeyUserID and under ctxkey.UserID in the request context. For OAuth2
// requests the token scopes are stored under consts.KeyOAuth2AccessScopes.
//
// Failures are reported as follows:
//   - herror.Unauthorized for a malformed header, unknown or expired token,
//     missing login, or missing user.
//   - herror.Forbidden for a suspended account.
//   - herror.InternalServerError for any other repository or session error.
func UserAuthenticate(repo repository.Repository, sessStore session.Store) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			var uid uuid.UUID
			if ah := c.Request().Header.Get(echo.HeaderAuthorization); len(ah) > 0 {
				// An Authorization header is present, so authenticate via OAuth2.
				// The scheme must be followed by a single space and a non-empty
				// token. The scheme itself is case-insensitive (RFC 7235 §2.1).
				scheme, accessToken, ok := strings.Cut(ah, " ")
				if !ok || !strings.EqualFold(scheme, authScheme) || len(accessToken) == 0 {
					return herror.Unauthorized("invalid authorization scheme")
				}

				// Look up the OAuth2 token.
				token, err := repo.GetTokenByAccess(accessToken)
				if err != nil {
					if errors.Is(err, repository.ErrNotFound) {
						return herror.Unauthorized("invalid token")
					}
					return herror.InternalServerError(err)
				}

				// Reject tokens past their expiry.
				if token.IsExpired() {
					return herror.Unauthorized("invalid token")
				}
				c.Set(consts.KeyOAuth2AccessScopes, token.Scopes)
				uid = token.UserID
			} else {
				// No Authorization header, so fall back to the session.
				sess, err := sessStore.GetSession(c)
				if err != nil {
					return herror.InternalServerError(err)
				}
				if sess == nil || !sess.LoggedIn() {
					return herror.Unauthorized("You are not logged in")
				}
				uid = sess.UserID()
			}

			// Load the authenticated user.
			user, err := repo.GetUser(uid, true)
			if err != nil {
				if errors.Is(err, repository.ErrNotFound) {
					// A token or session referring to a nonexistent user is an
					// authentication failure, not a server fault.
					return herror.Unauthorized("user not found")
				}
				return herror.InternalServerError(err)
			}

			// Check the account state.
			if !user.IsActive() {
				return herror.Forbidden("this account is currently suspended")
			}

			c.Set(consts.KeyUser, user)
			c.Set(consts.KeyUserID, user.GetID())
			// The SSE streamer reads the user ID from the request context.
			c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), ctxkey.UserID, user.GetID())))
			return next(c)
		}
	}
}