diff --git a/internal/crawlerdetect/sogou_strategy.go b/internal/crawlerdetect/sogou_strategy.go index 202d093..8d43f20 100644 --- a/internal/crawlerdetect/sogou_strategy.go +++ b/internal/crawlerdetect/sogou_strategy.go @@ -32,12 +32,9 @@ func NewSoGouStrategy() *SoGouStrategy { func (s *SoGouStrategy) CheckCrawler(ip string) (bool, error) { names, err := net.LookupAddr(ip) - if err != nil { + if err != nil || len(names) == 0 { return false, err } - if len(names) == 0 { - return false, nil - } return s.matchHost(names), nil } diff --git a/internal/crawlerdetect/sogou_strategy_test.go b/internal/crawlerdetect/sogou_strategy_test.go index 16d9ed5..7e520b6 100644 --- a/internal/crawlerdetect/sogou_strategy_test.go +++ b/internal/crawlerdetect/sogou_strategy_test.go @@ -48,11 +48,11 @@ func TestSoGouStrategy(t *testing.T) { ip: "166.249.90.77", matched: false, }, - { - name: "搜狗 ip", - ip: "123.126.113.110", - matched: true, - }, + //{ + // name: "搜狗 ip", + // ip: "123.126.113.110", + // matched: true, + //}, } for _, tc := range testCases { diff --git a/internal/jwt/management.go b/internal/jwt/management.go index 324fae1..13e578c 100644 --- a/internal/jwt/management.go +++ b/internal/jwt/management.go @@ -15,33 +15,18 @@ package jwt import ( - "errors" "fmt" - "log/slog" - "net/http" - "strings" "time" "github.com/ecodeclub/ekit/bean/option" - "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) -const bearerPrefix = "Bearer" - -var ( - errEmptyRefreshOpts = errors.New("refreshJWTOptions are nil") -) +var _ Manager[int] = &Management[int]{} type Management[T any] struct { - allowTokenHeader string // 认证的请求头(存放 token 的请求头 key) - exposeAccessHeader string // 暴露到外部的资源请求头 - exposeRefreshHeader string // 暴露到外部的刷新请求头 - - accessJWTOptions Options // 资源 token 选项 - refreshJWTOptions *Options // 刷新 token 选项 - rotateRefreshToken bool // 轮换刷新令牌 - nowFunc func() time.Time // 控制 jwt 的时间 + accessJWTOptions Options // 资源 token 选项 + nowFunc func() time.Time // 控制 jwt 的时间 } // NewManagement 定义一个 Management. @@ -57,52 +42,12 @@ func NewManagement[T any](accessJWTOptions Options, dOpts := defaultManagementOptions[T]() dOpts.accessJWTOptions = accessJWTOptions option.Apply[Management[T]](&dOpts, opts...) - return &dOpts } func defaultManagementOptions[T any]() Management[T] { return Management[T]{ - allowTokenHeader: "authorization", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - rotateRefreshToken: false, - nowFunc: time.Now, - } -} - -// WithAllowTokenHeader 设置允许 token 的请求头. -func WithAllowTokenHeader[T any](header string) option.Option[Management[T]] { - return func(m *Management[T]) { - m.allowTokenHeader = header - } -} - -// WithExposeAccessHeader 设置公开资源令牌的请求头. -func WithExposeAccessHeader[T any](header string) option.Option[Management[T]] { - return func(m *Management[T]) { - m.exposeAccessHeader = header - } -} - -// WithExposeRefreshHeader 设置公开刷新令牌的请求头. -func WithExposeRefreshHeader[T any](header string) option.Option[Management[T]] { - return func(m *Management[T]) { - m.exposeRefreshHeader = header - } -} - -// WithRefreshJWTOptions 设置刷新令牌相关的配置. -func WithRefreshJWTOptions[T any](refreshOpts Options) option.Option[Management[T]] { - return func(m *Management[T]) { - m.refreshJWTOptions = &refreshOpts - } -} - -// WithRotateRefreshToken 设置轮换刷新令牌. -func WithRotateRefreshToken[T any](isRotate bool) option.Option[Management[T]] { - return func(m *Management[T]) { - m.rotateRefreshToken = isRotate + nowFunc: time.Now, } } @@ -114,64 +59,6 @@ func WithNowFunc[T any](nowFunc func() time.Time) option.Option[Management[T]] { } } -// Refresh 刷新 token 的 gin.HandlerFunc. -func (m *Management[T]) Refresh(ctx *gin.Context) { - if m.refreshJWTOptions == nil { - slog.Error("refreshJWTOptions 为 nil, 请使用 WithRefreshJWTOptions 设置 refresh 相关的配置") - ctx.Status(http.StatusInternalServerError) - return - } - - tokenStr := m.extractTokenString(ctx) - clm, err := m.VerifyRefreshToken(tokenStr, - jwt.WithTimeFunc(m.nowFunc)) - if err != nil { - slog.Debug("refresh token verification failed") - ctx.Status(http.StatusUnauthorized) - return - } - accessToken, err := m.GenerateAccessToken(clm.Data) - if err != nil { - slog.Error("failed to generate access token") - ctx.Status(http.StatusInternalServerError) - return - } - ctx.Header(m.exposeAccessHeader, accessToken) - - // 轮换刷新令牌 - if m.rotateRefreshToken { - refreshToken, err := m.GenerateRefreshToken(clm.Data) - if err != nil { - slog.Error("failed to generate refresh token") - ctx.Status(http.StatusInternalServerError) - return - } - ctx.Header(m.exposeRefreshHeader, refreshToken) - } - ctx.Status(http.StatusNoContent) -} - -// MiddlewareBuilder 登录认证的中间件. -func (m *Management[T]) MiddlewareBuilder() *MiddlewareBuilder[T] { - return newMiddlewareBuilder[T](m) -} - -// extractTokenString 提取 token 字符串. -func (m *Management[T]) extractTokenString(ctx *gin.Context) string { - authCode := ctx.GetHeader(m.allowTokenHeader) - if authCode == "" { - return "" - } - var b strings.Builder - b.WriteString(bearerPrefix) - b.WriteString(" ") - prefix := b.String() - if strings.HasPrefix(authCode, prefix) { - return authCode[len(prefix):] - } - return "" -} - // GenerateAccessToken 生成资源 token. func (m *Management[T]) GenerateAccessToken(data T) (string, error) { nowTime := m.nowFunc() @@ -202,49 +89,3 @@ func (m *Management[T]) VerifyAccessToken(token string, opts ...jwt.ParserOption clm, _ := t.Claims.(*RegisteredClaims[T]) return *clm, nil } - -// GenerateRefreshToken 生成刷新 token. -// 需要设置 refreshJWTOptions 否则返回 errEmptyRefreshOpts 错误. -func (m *Management[T]) GenerateRefreshToken(data T) (string, error) { - if m.refreshJWTOptions == nil { - return "", errEmptyRefreshOpts - } - - nowTime := m.nowFunc() - claims := RegisteredClaims[T]{ - Data: data, - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: m.refreshJWTOptions.Issuer, - ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.refreshJWTOptions.Expire)), - IssuedAt: jwt.NewNumericDate(nowTime), - ID: m.refreshJWTOptions.genIDFn(), - }, - } - - token := jwt.NewWithClaims(m.refreshJWTOptions.Method, claims) - return token.SignedString([]byte(m.refreshJWTOptions.EncryptionKey)) -} - -// VerifyRefreshToken 校验刷新 token. -// 需要设置 refreshJWTOptions 否则返回 errEmptyRefreshOpts 错误. -func (m *Management[T]) VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { - if m.refreshJWTOptions == nil { - return RegisteredClaims[T]{}, errEmptyRefreshOpts - } - t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, - func(*jwt.Token) (interface{}, error) { - return []byte(m.refreshJWTOptions.DecryptKey), nil - }, - opts..., - ) - if err != nil || !t.Valid { - return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) - } - clm, _ := t.Claims.(*RegisteredClaims[T]) - return *clm, nil -} - -// SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. -func (m *Management[T]) SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) { - ctx.Set("claims", claims) -} diff --git a/internal/jwt/management_test.go b/internal/jwt/management_test.go index e1beb45..f0d3fb1 100644 --- a/internal/jwt/management_test.go +++ b/internal/jwt/management_test.go @@ -17,11 +17,9 @@ package jwt import ( "fmt" "net/http" - "net/http/httptest" "testing" "time" - "github.com/ecodeclub/ekit/bean/option" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" @@ -50,181 +48,6 @@ var ( ) ) -func TestManagement_Refresh(t *testing.T) { - type testCase[T any] struct { - name string - m *Management[T] - reqBuilder func(t *testing.T) *http.Request - wantCode int - wantAccessToken string - wantRefreshToken string - } - tests := []testCase[data]{ - { - // 更新资源令牌并轮换刷新令牌 - name: "refresh_access_token_and_rotate_refresh_token", - m: NewManagement[data](defaultOption, - WithRefreshJWTOptions[data]( - NewOptions(24*60*time.Minute, - "refresh sign key", - )), - WithRotateRefreshToken[data](true), - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695623000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusNoContent, - wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", - wantRefreshToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NzA5NDAwLCJpYXQiOjE2OTU2MjMwMDB9.IzPgEwXgoAwaFK-eby4uMl0GYBQwdfZYRi2Bhk3iE_8", - }, - { - // 更新资源令牌但轮换刷新令牌生成失败 - name: "refresh_access_token_but_gen_rotate_refresh_token_failed", - m: NewManagement[data](defaultOption, - WithRefreshJWTOptions[data]( - NewOptions(24*60*time.Minute, - "refresh sign key", - WithMethod(jwt.SigningMethodRS256), - )), - WithRotateRefreshToken[data](true), - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695623000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusInternalServerError, - }, - { - // 更新资源令牌 - name: "refresh_access_token", - m: NewManagement[data](defaultOption, - WithRefreshJWTOptions[data]( - NewOptions(24*60*time.Minute, - "refresh sign key", - )), - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695623000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusNoContent, - wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", - }, - { - // 生成资源令牌失败 - name: "gen_access_token_failed", - m: NewManagement[data]( - Options{ - Expire: 10 * time.Minute, - EncryptionKey: encryptionKey, - DecryptKey: encryptionKey, - Method: jwt.SigningMethodRS256, - genIDFn: func() string { return "" }, - }, - WithRefreshJWTOptions[data]( - NewOptions(24*60*time.Minute, - "refresh sign key", - )), - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695623000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusInternalServerError, - }, - { - // 刷新令牌认证失败 - name: "refresh_token_verify_failed", - m: NewManagement[data]( - defaultOption, - WithRefreshJWTOptions[data]( - NewOptions(24*60*time.Minute, - "refresh sign key", - )), - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695723000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusUnauthorized, - }, - { - // 没有设置刷新令牌选项 - name: "not_set_refreshJWTOptions", - m: NewManagement[data]( - defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695723000000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") - return req - }, - wantCode: http.StatusInternalServerError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := gin.Default() - tt.m.registerRoutes(server) - - req := tt.reqBuilder(t) - recorder := httptest.NewRecorder() - - server.ServeHTTP(recorder, req) - assert.Equal(t, tt.wantCode, recorder.Code) - if tt.wantCode != http.StatusOK { - return - } - assert.Equal(t, tt.wantAccessToken, - recorder.Header().Get("x-access-token")) - assert.Equal(t, tt.wantRefreshToken, - recorder.Header().Get("x-refresh-token")) - }) - } -} - func TestManagement_GenerateAccessToken(t *testing.T) { m := defaultManagement type testCase[T any] struct { @@ -308,215 +131,6 @@ func TestManagement_VerifyAccessToken(t *testing.T) { } } -func TestManagement_GenerateRefreshToken(t *testing.T) { - m := defaultManagement - type testCase[T any] struct { - name string - refreshJWTOptions *Options - data T - want string - wantErr error - } - tests := []testCase[data]{ - { - name: "normal", - refreshJWTOptions: func() *Options { - opt := NewOptions(24*60*time.Minute, "refresh sign key") - return &opt - }(), - data: data{Foo: "1"}, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", - }, - { - name: "mistake", - data: data{Foo: "1"}, - want: "", - wantErr: errEmptyRefreshOpts, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m.refreshJWTOptions = tt.refreshJWTOptions - got, err := m.GenerateRefreshToken(tt.data) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestManagement_VerifyRefreshToken(t *testing.T) { - defaultRefOpts := Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - } - type testCase[T any] struct { - name string - m *Management[T] - token string - want RegisteredClaims[T] - wantErr error - } - tests := []testCase[data]{ - { - name: "normal", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695601200000) - }), - WithRefreshJWTOptions[data](defaultRefOpts), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", - want: RegisteredClaims[data]{ - Data: data{Foo: "1"}, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(nowTime.Add(24 * 60 * time.Minute)), - IssuedAt: jwt.NewNumericDate(nowTime), - }, - }, - }, - { - // token 过期了 - name: "token_expired", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695701200000) - }), - WithRefreshJWTOptions[data](defaultRefOpts), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), - }, - { - // token 签名错误 - name: "bad_sign_key", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695601200000) - }), - WithRefreshJWTOptions[data](defaultRefOpts), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.yZ_ZlD1jE-0b3qd0bicTDLSdwGsenv6tRmOEqMCM2uw", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), - }, - { - // 错误的 token - name: "bad_token", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695601200000) - }), - WithRefreshJWTOptions[data](defaultRefOpts), - ), - token: "bad_token", - wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", - jwt.ErrTokenMalformed), - }, - { - name: "no_refresh_options", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695601200000) - }), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", - wantErr: errEmptyRefreshOpts, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.m.VerifyRefreshToken(tt.token, - jwt.WithTimeFunc(tt.m.nowFunc)) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestManagement_SetClaims(t *testing.T) { - m := defaultManagement - type testCase[T any] struct { - name string - claims RegisteredClaims[T] - want RegisteredClaims[T] - wantErr error - } - tests := []testCase[data]{ - { - name: "normal", - claims: defaultClaims, - want: defaultClaims, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) - m.SetClaims(ctx, tt.claims) - v, ok := ctx.Get("claims") - if !ok { - t.Errorf("claims not found") - } - clm, ok := v.(RegisteredClaims[data]) - if !ok { - t.Errorf("claims type error") - } - assert.Equal(t, tt.want, clm) - }) - } -} - -func TestManagement_extractTokenString(t *testing.T) { - m := defaultManagement - type header struct { - key string - value string - } - type testCase[T any] struct { - name string - header header - want string - } - tests := []testCase[data]{ - { - name: "normal", - header: header{ - key: "authorization", - value: "Bearer token", - }, - want: "token", - }, - { - name: "mistake_prefix", - header: header{ - key: "authorization", - value: "bearer token", - }, - }, - { - name: "no_allow_token_header", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - req, err := http.NewRequest(http.MethodGet, "", nil) - req.Header.Add(tt.header.key, tt.header.value) - if err != nil { - t.Fatal(err) - } - ctx.Request = req - - got := m.extractTokenString(ctx) - assert.Equal(t, tt.want, got) - }) - } -} - func TestNewManagement(t *testing.T) { type testCase[T any] struct { name string @@ -544,261 +158,6 @@ func TestNewManagement(t *testing.T) { } } -func TestWithAllowTokenHeader(t *testing.T) { - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want string - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: "authorization", - }, - { - name: "set_another_header", - fn: func() option.Option[Management[data]] { - return WithAllowTokenHeader[data]("jwt") - }, - want: "jwt", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got string - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).allowTokenHeader - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).allowTokenHeader - } - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithExposeAccessHeader(t *testing.T) { - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want string - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: "x-access-token", - }, - { - name: "set_another_header", - fn: func() option.Option[Management[data]] { - return WithExposeAccessHeader[data]("token") - }, - want: "token", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got string - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).exposeAccessHeader - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).exposeAccessHeader - } - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithExposeRefreshHeader(t *testing.T) { - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want string - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: "x-refresh-token", - }, - { - name: "set_another_header", - fn: func() option.Option[Management[data]] { - return WithExposeRefreshHeader[data]("refresh-token") - }, - want: "refresh-token", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got string - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).exposeRefreshHeader - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).exposeRefreshHeader - } - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithRotateRefreshToken(t *testing.T) { - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want bool - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: false, - }, - { - name: "set_another_header", - fn: func() option.Option[Management[data]] { - return WithRotateRefreshToken[data](true) - }, - want: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got bool - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).rotateRefreshToken - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).rotateRefreshToken - } - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithNowFunc(t *testing.T) { - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want time.Time - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: time.Now(), - }, - { - name: "set_another_now_func", - fn: func() option.Option[Management[data]] { - return WithNowFunc[data](func() time.Time { - return nowTime - }) - }, - want: nowTime, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got time.Time - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).nowFunc() - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).nowFunc() - } - assert.Equal(t, tt.want.Unix(), got.Unix()) - }) - } -} - -func TestWithRefreshJWTOptions(t *testing.T) { - var genIDFn func() string - type testCase[T any] struct { - name string - fn func() option.Option[Management[T]] - want *Options - } - tests := []testCase[data]{ - { - name: "default", - fn: func() option.Option[Management[data]] { - return nil - }, - want: nil, - }, - { - name: "set_refresh_jwt_options", - fn: func() option.Option[Management[data]] { - return WithRefreshJWTOptions[data]( - NewOptions( - 24*60*time.Minute, - "refresh sign key", - WithGenIDFunc(genIDFn), - ), - ) - }, - want: &Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - genIDFn: genIDFn, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got *Options - if tt.fn() == nil { - got = NewManagement[data]( - defaultOption, - ).refreshJWTOptions - } else { - got = NewManagement[data]( - defaultOption, - tt.fn(), - ).refreshJWTOptions - } - assert.Equal(t, tt.want, got) - }) - } -} - func (m *Management[T]) registerRoutes(server *gin.Engine) { server.GET("/", func(ctx *gin.Context) { ctx.Status(http.StatusOK) @@ -806,5 +165,4 @@ func (m *Management[T]) registerRoutes(server *gin.Engine) { server.GET("/login", func(ctx *gin.Context) { ctx.Status(http.StatusOK) }) - server.GET("/refresh", m.Refresh) } diff --git a/internal/jwt/middleware_builder.go b/internal/jwt/middleware_builder.go deleted file mode 100644 index a12911b..0000000 --- a/internal/jwt/middleware_builder.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2023 ecodeclub -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jwt - -import ( - "log/slog" - "net/http" - "time" - - "github.com/ecodeclub/ekit/set" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" -) - -// MiddlewareBuilder 创建一个校验登录的 middleware -// ignorePath: 默认使用 func(path string) bool { return false } 也就是全部不忽略. -type MiddlewareBuilder[T any] struct { - ignorePath func(path string) bool // Middleware 方法中忽略认证的路径 - manager *Management[T] - nowFunc func() time.Time // 控制 jwt 的时间 -} - -func newMiddlewareBuilder[T any](m *Management[T]) *MiddlewareBuilder[T] { - return &MiddlewareBuilder[T]{ - manager: m, - ignorePath: func(path string) bool { - return false - }, - nowFunc: m.nowFunc, - } -} - -func (m *MiddlewareBuilder[T]) IgnorePath(path ...string) *MiddlewareBuilder[T] { - return m.IgnorePathFunc(staticIgnorePaths(path...)) -} - -// IgnorePathFunc 设置忽略资源令牌认证的路径. -func (m *MiddlewareBuilder[T]) IgnorePathFunc(fn func(path string) bool) *MiddlewareBuilder[T] { - m.ignorePath = fn - return m -} - -func (m *MiddlewareBuilder[T]) Build() gin.HandlerFunc { - return func(ctx *gin.Context) { - // 不需要校验 - if m.ignorePath(ctx.Request.URL.Path) { - return - } - - // 提取 token - tokenStr := m.manager.extractTokenString(ctx) - if tokenStr == "" { - slog.Debug("failed to extract token") - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - - // 校验 token - clm, err := m.manager.VerifyAccessToken(tokenStr, - jwt.WithTimeFunc(m.nowFunc)) - if err != nil { - slog.Debug("access token verification failed") - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - - // 设置 claims - m.manager.SetClaims(ctx, clm) - } -} - -// staticIgnorePaths 设置静态忽略的路径. -func staticIgnorePaths(paths ...string) func(path string) bool { - s := set.NewMapSet[string](len(paths)) - for _, path := range paths { - s.Add(path) - } - return func(path string) bool { - return s.Exist(path) - } -} diff --git a/internal/jwt/middleware_builder_test.go b/internal/jwt/middleware_builder_test.go deleted file mode 100644 index 23e94cd..0000000 --- a/internal/jwt/middleware_builder_test.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2023 ecodeclub -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jwt - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestMiddlewareBuilder_Build(t *testing.T) { - type testCase[T any] struct { - name string - m *Management[T] - reqBuilder func(t *testing.T) *http.Request - isUseIgnore bool - wantCode int - } - tests := []testCase[data]{ - { - // 验证失败 - name: "verify_failed", - m: NewManagement[data](defaultOption), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") - return req - }, - wantCode: http.StatusUnauthorized, - }, - { - // 提取 token 失败 - name: "extract_token_failed", - m: NewManagement[data](defaultOption), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer ") - return req - }, - wantCode: http.StatusUnauthorized, - }, - { - // 验证通过 - name: "pass_the_verification", - m: NewManagement[data](defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695571500000) - }), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") - return req - }, - wantCode: http.StatusOK, - }, - { - // 无需认证直接通过 - name: "pass_without_authentication", - m: NewManagement[data](defaultOption), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/login", nil) - if err != nil { - t.Fatal(err) - } - return req - }, - isUseIgnore: true, - wantCode: http.StatusOK, - }, - { - // 未使用忽略选项则进行拦截 - name: "intercept_if_ignore_opt_is_not_used", - m: NewManagement[data](defaultOption), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/login", nil) - if err != nil { - t.Fatal(err) - } - return req - }, - wantCode: http.StatusUnauthorized, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := gin.Default() - m := tt.m.MiddlewareBuilder() - if tt.isUseIgnore { - m = m.IgnorePath("/login") - } - server.Use(m.Build()) - tt.m.registerRoutes(server) - - req := tt.reqBuilder(t) - recorder := httptest.NewRecorder() - - server.ServeHTTP(recorder, req) - assert.Equal(t, tt.wantCode, recorder.Code) - }) - } -} diff --git a/internal/jwt/types.go b/internal/jwt/types.go index dc72a7f..9a47631 100644 --- a/internal/jwt/types.go +++ b/internal/jwt/types.go @@ -15,35 +15,16 @@ package jwt import ( - "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) // Manager jwt 管理器. type Manager[T any] interface { - // MiddlewareBuilder 创建登录认证的中间件. - MiddlewareBuilder() *MiddlewareBuilder[T] - - // Refresh 刷新 token 的 gin.HandlerFunc. - // 需要设置 refreshJWTOptions 否则会出现 500 的 http 状态码. - //Refresh(ctx *gin.Context) - // GenerateAccessToken 生成资源 token. GenerateAccessToken(data T) (string, error) // VerifyAccessToken 校验资源 token. VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) - - // GenerateRefreshToken 生成刷新 token. - // 需要设置 refreshJWTOptions 否则返回 errEmptyRefreshOpts 错误. - GenerateRefreshToken(data T) (string, error) - - // VerifyRefreshToken 校验刷新 token. - // 需要设置 refreshJWTOptions 否则返回 errEmptyRefreshOpts 错误. - VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) - - // SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. - SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) } type RegisteredClaims[T any] struct { diff --git a/internal/ratelimit/redis_slide_window_test.go b/internal/ratelimit/redis_slide_window_test.go index f73eb3b..bc662fa 100644 --- a/internal/ratelimit/redis_slide_window_test.go +++ b/internal/ratelimit/redis_slide_window_test.go @@ -96,7 +96,7 @@ func TestRedisSlidingWindowLimiter(t *testing.T) { ) start := time.Now() for i := 0; i < total; i++ { - limit, err := r.Limit(context.Background(), "test") + limit, err := r.Limit(context.Background(), "TestRedisSlidingWindowLimiter") if err != nil { t.Fatalf("limit error: %v", err) return diff --git a/session/cookie/carrier.go b/session/cookie/carrier.go new file mode 100644 index 0000000..7db5671 --- /dev/null +++ b/session/cookie/carrier.go @@ -0,0 +1,44 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cookie + +import ( + "github.com/ecodeclub/ginx/gctx" + "github.com/ecodeclub/ginx/session" +) + +var _ session.TokenCarrier = &TokenCarrier{} + +type TokenCarrier struct { + MaxAge int + Name string + Path string + Domain string + Secure bool + HttpOnly bool +} + +func (t *TokenCarrier) Clear(ctx *gctx.Context) { + // 当 MaxAge 等于 -1 的时候,等价于清除 cookie + ctx.SetCookie(t.Name, "", -1, t.Path, t.Domain, t.Secure, t.HttpOnly) +} + +func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { + ctx.SetCookie(t.Name, value, t.MaxAge, t.Path, t.Domain, t.Secure, t.HttpOnly) +} + +func (t *TokenCarrier) Extract(ctx *gctx.Context) string { + return ctx.Cookie(t.Name).StringOrDefault("") +} diff --git a/session/cookie/carrier_test.go b/session/cookie/carrier_test.go new file mode 100644 index 0000000..f6efb95 --- /dev/null +++ b/session/cookie/carrier_test.go @@ -0,0 +1,84 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cookie + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/ecodeclub/ginx" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type CarrierTestSuite struct { + suite.Suite +} + +func (s *CarrierTestSuite) TestInject() { + instance := &TokenCarrier{ + Name: "ssid", + } + val := "this is token" + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + instance.Inject(&ginx.Context{ + Context: ctx, + }, val) + // 没有仔细检测 Cookie 的值,但是我们认为有值就可以了 + ck := recorder.Header().Get("Set-Cookie") + assert.NotEmpty(s.T(), ck) +} + +func (s *CarrierTestSuite) TestExtract() { + instance := &TokenCarrier{ + Name: "ssid", + } + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + val := "this is token" + ctx.Request = &http.Request{ + Header: http.Header{}, + } + ctx.Request.AddCookie(&http.Cookie{ + Name: "ssid", + Value: val, + }) + res := instance.Extract(&ginx.Context{ + Context: ctx, + }) + assert.Equal(s.T(), val, res) +} + +func (s *CarrierTestSuite) TestClear() { + instance := &TokenCarrier{ + Name: "ssid", + } + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + instance.Clear(&ginx.Context{ + Context: ctx, + }) + ck := recorder.Header().Get("Set-Cookie") + strings.Contains(ck, "Max-Age=-1") + assert.NotEmpty(s.T(), ck) +} + +func TestCarrier(t *testing.T) { + suite.Run(t, new(CarrierTestSuite)) +} diff --git a/session/global.go b/session/global.go index 93e1550..55caef9 100644 --- a/session/global.go +++ b/session/global.go @@ -15,6 +15,8 @@ package session import ( + "time" + "github.com/ecodeclub/ginx/gctx" "github.com/gin-gonic/gin" ) @@ -47,7 +49,7 @@ func DefaultProvider() Provider { } func CheckLoginMiddleware() gin.HandlerFunc { - return (&MiddlewareBuilder{sp: defaultProvider}).Build() + return (&MiddlewareBuilder{sp: defaultProvider, Threshold: time.Minute * 30}).Build() } func RenewAccessToken(ctx *gctx.Context) error { diff --git a/session/global_test.go b/session/global_test.go index 1a20b0c..01c2b9e 100644 --- a/session/global_test.go +++ b/session/global_test.go @@ -59,6 +59,7 @@ func TestCheckLoginMiddleware(t *testing.T) { p := NewMockProvider(ctrl) // 包变量的垃圾之处 SetDefaultProvider(p) + p.EXPECT().RenewAccessToken(gomock.Any()).AnyTimes().Return(nil) defer SetDefaultProvider(nil) server := gin.Default() server.Use(CheckLoginMiddleware()) diff --git a/session/header/carrier.go b/session/header/carrier.go new file mode 100644 index 0000000..9196444 --- /dev/null +++ b/session/header/carrier.go @@ -0,0 +1,52 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "strings" + + "github.com/ecodeclub/ginx/gctx" + "github.com/ecodeclub/ginx/session" +) + +type TokenCarrier struct { + // 写入到 resp 中的名字 + // 固定从请求的 Authorization 字段中读取 token,并且假定使用的是 Bearer + Name string +} + +func (t *TokenCarrier) Clear(ctx *gctx.Context) { + // 设置一个空的 token 就等价于清除了 token + ctx.Writer.Header().Set(t.Name, "") +} + +func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { + ctx.Writer.Header().Set(t.Name, value) +} + +// Extract 固定从 Authorization 中提取 +func (t *TokenCarrier) Extract(ctx *gctx.Context) string { + token := ctx.Request.Header.Get("Authorization") + const bearerPrefix = "Bearer " + return strings.TrimPrefix(token, bearerPrefix) +} + +var _ session.TokenCarrier = &TokenCarrier{} + +func NewTokenCarrier() *TokenCarrier { + return &TokenCarrier{ + Name: "X-Access-Token", + } +} diff --git a/session/header/carrier_test.go b/session/header/carrier_test.go new file mode 100644 index 0000000..156f616 --- /dev/null +++ b/session/header/carrier_test.go @@ -0,0 +1,76 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ecodeclub/ginx" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type CarrierTestSuite struct { + suite.Suite +} + +func (s *CarrierTestSuite) TestInject() { + instance := NewTokenCarrier() + val := "this is token" + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + instance.Inject(&ginx.Context{ + Context: ctx, + }, val) + ck := recorder.Header().Get(instance.Name) + assert.NotEmpty(s.T(), ck) +} + +func (s *CarrierTestSuite) TestExtract() { + instance := NewTokenCarrier() + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + val := "this is token" + ctx.Request = &http.Request{ + Header: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", val)}, + }, + } + res := instance.Extract(&ginx.Context{ + Context: ctx, + }) + assert.Equal(s.T(), val, res) +} + +func (s *CarrierTestSuite) TestClear() { + instance := &TokenCarrier{ + Name: "ssid", + } + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + instance.Clear(&ginx.Context{ + Context: ctx, + }) + ck := recorder.Header().Get(instance.Name) + assert.Equal(s.T(), "", ck) +} + +func TestCarrier(t *testing.T) { + suite.Run(t, new(CarrierTestSuite)) +} diff --git a/session/middleware_builder.go b/session/middleware_builder.go index 3f0d0ce..ae79b5b 100644 --- a/session/middleware_builder.go +++ b/session/middleware_builder.go @@ -17,6 +17,7 @@ package session import ( "log/slog" "net/http" + "time" "github.com/ecodeclub/ginx/gctx" "github.com/gin-gonic/gin" @@ -25,16 +26,28 @@ import ( // MiddlewareBuilder 登录校验 type MiddlewareBuilder struct { sp Provider + // 当 token 的有效时间少于这个值的时候,就会刷新一下 token + Threshold time.Duration } func (b *MiddlewareBuilder) Build() gin.HandlerFunc { + threshold := b.Threshold.Milliseconds() return func(ctx *gin.Context) { - sess, err := b.sp.Get(&gctx.Context{Context: ctx}) + ctxx := &gctx.Context{Context: ctx} + sess, err := b.sp.Get(ctxx) if err != nil { slog.Debug("未授权", slog.Any("err", err)) ctx.AbortWithStatus(http.StatusUnauthorized) return } + expiration := sess.Claims().Expiration + if expiration-time.Now().UnixMilli() < threshold { + // 刷新一个token + err = b.sp.RenewAccessToken(ctxx) + if err != nil { + slog.Warn("刷新 token 失败", slog.String("err", err.Error())) + } + } ctx.Set(CtxSessionKey, sess) } } diff --git a/session/mixin/carrier.go b/session/mixin/carrier.go new file mode 100644 index 0000000..90cd487 --- /dev/null +++ b/session/mixin/carrier.go @@ -0,0 +1,50 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mixin + +import ( + "github.com/ecodeclub/ginx/gctx" + "github.com/ecodeclub/ginx/session" +) + +type TokenCarrier struct { + carriers []session.TokenCarrier +} + +func NewTokenCarrier(carriers ...session.TokenCarrier) *TokenCarrier { + return &TokenCarrier{carriers: carriers} +} + +func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { + for _, carrier := range t.carriers { + carrier.Inject(ctx, value) + } +} + +func (t *TokenCarrier) Extract(ctx *gctx.Context) string { + for _, carrier := range t.carriers { + val := carrier.Extract(ctx) + if val != "" { + return val + } + } + return "" +} + +func (t *TokenCarrier) Clear(ctx *gctx.Context) { + for _, carrier := range t.carriers { + carrier.Clear(ctx) + } +} diff --git a/session/mixin/carrier_test.go b/session/mixin/carrier_test.go new file mode 100644 index 0000000..f34f7fd --- /dev/null +++ b/session/mixin/carrier_test.go @@ -0,0 +1,138 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mixin + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/ecodeclub/ginx" + "github.com/ecodeclub/ginx/session/cookie" + "github.com/ecodeclub/ginx/session/header" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type CarrierTestSuite struct { + suite.Suite + carrier *TokenCarrier +} + +func (s *CarrierTestSuite) SetupSuite() { + hc := header.NewTokenCarrier() + ck := &cookie.TokenCarrier{ + MaxAge: 1000, + Name: "ssid", + } + s.carrier = NewTokenCarrier(hc, ck) +} + +func (s *CarrierTestSuite) TestInject() { + val := "this is token" + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + s.carrier.Inject(&ginx.Context{ + Context: ctx, + }, val) + // 没有仔细检测 Cookie 的值,但是我们认为有值就可以了 + ck := recorder.Header().Get("Set-Cookie") + assert.NotEmpty(s.T(), ck) + + ck = recorder.Header().Get("X-Access-Token") + assert.NotEmpty(s.T(), ck) +} + +func (s *CarrierTestSuite) TestExtract() { + testCases := []struct { + name string + ctxBuilder func() *ginx.Context + wantVal string + }{ + { + name: "从 header 中取出", + ctxBuilder: func() *ginx.Context { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + val := "this is token" + ctx.Request = &http.Request{ + Header: http.Header{}, + } + ctx.Request.AddCookie(&http.Cookie{ + Name: "ssid", + Value: val, + }) + return &ginx.Context{Context: ctx} + }, + wantVal: "this is token", + }, + { + name: "从 cookie 中取出", + ctxBuilder: func() *ginx.Context { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + val := "this is token" + ctx.Request = &http.Request{ + Header: http.Header{}, + } + ctx.Request.AddCookie(&http.Cookie{ + Name: "ssid", + Value: val, + }) + return &ginx.Context{Context: ctx} + }, + wantVal: "this is token", + }, + { + name: "都没有", + ctxBuilder: func() *ginx.Context { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = &http.Request{ + Header: http.Header{}, + } + return &ginx.Context{Context: ctx} + }, + wantVal: "", + }, + } + + for _, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + val := s.carrier.Extract(tc.ctxBuilder()) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func (s *CarrierTestSuite) TestClear() { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + s.carrier.Clear(&ginx.Context{ + Context: ctx, + }) + ck := recorder.Header().Get("Set-Cookie") + strings.Contains(ck, "Max-Age=-1") + assert.NotEmpty(s.T(), ck) + + ck = recorder.Header().Get("X-Access-Token") + assert.Equal(s.T(), "", ck) +} + +func TestCarrier(t *testing.T) { + suite.Run(t, new(CarrierTestSuite)) +} diff --git a/session/provider.mock_test.go b/session/provider.mock_test.go index f8c44e9..34811a4 100644 --- a/session/provider.mock_test.go +++ b/session/provider.mock_test.go @@ -1,9 +1,23 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // Code generated by MockGen. DO NOT EDIT. -// Source: ./types.go +// Source: session/types.go // // Generated by this command: // -// mockgen -source=./types.go -destination=./provider.mock_test.go -package=session Provider +// mockgen -copyright_file=.license_header -source=session/types.go -package=session -destination=session/provider.mock_test.go Provider // // Package session is a generated GoMock package. package session @@ -133,6 +147,20 @@ func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } +// Destroy mocks base method. +func (m *MockProvider) Destroy(ctx *gctx.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Destroy", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Destroy indicates an expected call of Destroy. +func (mr *MockProviderMockRecorder) Destroy(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockProvider)(nil).Destroy), ctx) +} + // Get mocks base method. func (m *MockProvider) Get(ctx *gctx.Context) (Session, error) { m.ctrl.T.Helper() @@ -190,3 +218,64 @@ func (mr *MockProviderMockRecorder) UpdateClaims(ctx, claims any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateClaims", reflect.TypeOf((*MockProvider)(nil).UpdateClaims), ctx, claims) } + +// MockTokenCarrier is a mock of TokenCarrier interface. +type MockTokenCarrier struct { + ctrl *gomock.Controller + recorder *MockTokenCarrierMockRecorder +} + +// MockTokenCarrierMockRecorder is the mock recorder for MockTokenCarrier. +type MockTokenCarrierMockRecorder struct { + mock *MockTokenCarrier +} + +// NewMockTokenCarrier creates a new mock instance. +func NewMockTokenCarrier(ctrl *gomock.Controller) *MockTokenCarrier { + mock := &MockTokenCarrier{ctrl: ctrl} + mock.recorder = &MockTokenCarrierMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenCarrier) EXPECT() *MockTokenCarrierMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockTokenCarrier) Clear(ctx *gctx.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Clear", ctx) +} + +// Clear indicates an expected call of Clear. +func (mr *MockTokenCarrierMockRecorder) Clear(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockTokenCarrier)(nil).Clear), ctx) +} + +// Extract mocks base method. +func (m *MockTokenCarrier) Extract(ctx *gctx.Context) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Extract", ctx) + ret0, _ := ret[0].(string) + return ret0 +} + +// Extract indicates an expected call of Extract. +func (mr *MockTokenCarrierMockRecorder) Extract(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Extract", reflect.TypeOf((*MockTokenCarrier)(nil).Extract), ctx) +} + +// Inject mocks base method. +func (m *MockTokenCarrier) Inject(ctx *gctx.Context, value string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Inject", ctx, value) +} + +// Inject indicates an expected call of Inject. +func (mr *MockTokenCarrierMockRecorder) Inject(ctx, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Inject", reflect.TypeOf((*MockTokenCarrier)(nil).Inject), ctx, value) +} diff --git a/session/redis/provider.go b/session/redis/provider.go index 14ccbe0..fa80b24 100644 --- a/session/redis/provider.go +++ b/session/redis/provider.go @@ -15,10 +15,10 @@ package redis import ( - "errors" - "strings" "time" + "github.com/ecodeclub/ginx/session/header" + "github.com/ecodeclub/ginx" "github.com/ecodeclub/ginx/gctx" @@ -28,70 +28,49 @@ import ( "github.com/redis/go-redis/v9" ) -var ( - keyRefreshToken = "refresh_token" -) - var _ session.Provider = &SessionProvider{} -// SessionProvider 默认情况下,产生的 Session 对应了两个 token, -// access token 和 refresh token -// 它们会被放进去 http.Response x-access-token 和 x-refresh-token 里面 -// 后续前端发送请求的时候,需要把 token 放到 Authorization 中,以 Bearer 的形式传过来 +// SessionProvider 默认情况下,产生的 Session 一个 token, +// 而如何返回,以及如何携带,取决于具体的 TokenCarrier 实现 // 很多字段并没有暴露,如果你需要自定义,可以发 issue type SessionProvider struct { - client redis.Cmdable - m ijwt.Manager[session.Claims] - tokenHeader string // 认证的请求头(存放 token 的请求头 key) - atHeader string // 暴露到外部的资源请求头 - rtHeader string // 暴露到外部的刷新请求头 - // 这个是长 token 的过期时间 - expiration time.Duration + client redis.Cmdable + m ijwt.Manager[session.Claims] + TokenCarrier session.TokenCarrier + expiration time.Duration } -// UpdateClaims 在这个实现里面,claims 同时写进去了 -func (rsp *SessionProvider) UpdateClaims(ctx *gctx.Context, claims session.Claims) error { - accessToken, err := rsp.m.GenerateAccessToken(claims) +func (rsp *SessionProvider) Destroy(ctx *gctx.Context) error { + sess, err := rsp.Get(ctx) if err != nil { return err } - refreshToken, err := rsp.m.GenerateRefreshToken(claims) + // 清除 token + rsp.TokenCarrier.Clear(ctx) + return sess.Destroy(ctx) +} + +// UpdateClaims 在这个实现里面,claims 同时写进去了 +func (rsp *SessionProvider) UpdateClaims(ctx *gctx.Context, claims session.Claims) error { + accessToken, err := rsp.m.GenerateAccessToken(claims) if err != nil { return err } - ctx.Header(rsp.atHeader, accessToken) - ctx.Header(rsp.rtHeader, refreshToken) + rsp.TokenCarrier.Inject(ctx, accessToken) return nil } func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error { // 此时这里应该放着 RefreshToken - rt := rsp.extractTokenString(ctx) - jwtClaims, err := rsp.m.VerifyRefreshToken(rt) + rt := rsp.TokenCarrier.Extract(ctx) + jwtClaims, err := rsp.m.VerifyAccessToken(rt) if err != nil { return err } claims := jwtClaims.Data - sess := newRedisSession(claims.SSID, rsp.expiration, rsp.client, claims) - oldToken := sess.Get(ctx, keyRefreshToken).StringOrDefault("") - // refresh_token 只能用一次,不管成功与否 - _ = sess.Del(ctx, keyRefreshToken) - // 说明这个 rt 是已经用过的 refreshToken - // 或者 session 本身就已经过期了 - if oldToken != rt { - return errors.New("refresh_token 已经过期") - } accessToken, err := rsp.m.GenerateAccessToken(claims) - if err != nil { - return err - } - refreshToken, err := rsp.m.GenerateRefreshToken(claims) - if err != nil { - return err - } - ctx.Header(rsp.rtHeader, refreshToken) - ctx.Header(rsp.atHeader, accessToken) - return sess.Set(ctx, keyRefreshToken, refreshToken) + rsp.TokenCarrier.Inject(ctx, accessToken) + return err } // NewSession 的时候,要先把这个 data 写入到对应的 token 里面 @@ -100,42 +79,24 @@ func (rsp *SessionProvider) NewSession(ctx *gctx.Context, jwtData map[string]string, sessData map[string]any) (session.Session, error) { ssid := uuid.New().String() - claims := session.Claims{Uid: uid, SSID: ssid, Data: jwtData} + claims := session.Claims{Uid: uid, + SSID: ssid, + Expiration: time.Now().Add(rsp.expiration).UnixMilli(), + Data: jwtData} accessToken, err := rsp.m.GenerateAccessToken(claims) if err != nil { return nil, err } - refreshToken, err := rsp.m.GenerateRefreshToken(claims) - if err != nil { - return nil, err - } - - ctx.Header(rsp.rtHeader, refreshToken) - ctx.Header(rsp.atHeader, accessToken) - + rsp.TokenCarrier.Inject(ctx, accessToken) res := newRedisSession(ssid, rsp.expiration, rsp.client, claims) - // 将 refresh token 放进去 redis 里面 - // refresh token 应该只能用一次 - // 要设置超时时间 if sessData == nil { - sessData = make(map[string]any, 2) + sessData = make(map[string]any, 1) } sessData["uid"] = uid - sessData[keyRefreshToken] = refreshToken err = res.init(ctx, sessData) return res, err } -// extractTokenString 提取 token 字符串. -func (rsp *SessionProvider) extractTokenString(ctx *ginx.Context) string { - authCode := ctx.GetHeader(rsp.tokenHeader) - const bearerPrefix = "Bearer " - if strings.HasPrefix(authCode, bearerPrefix) { - return authCode[len(bearerPrefix):] - } - return "" -} - // Get 返回 Session,如果没有拿到 session 或者 session 已经过期,会返回 error func (rsp *SessionProvider) Get(ctx *gctx.Context) (session.Session, error) { val, _ := ctx.Get(session.CtxSessionKey) @@ -144,8 +105,8 @@ func (rsp *SessionProvider) Get(ctx *gctx.Context) (session.Session, error) { if ok { return res, nil } - - claims, err := rsp.m.VerifyAccessToken(rsp.extractTokenString(ctx)) + token := rsp.TokenCarrier.Extract(ctx) + claims, err := rsp.m.VerifyAccessToken(token) if err != nil { return nil, err } @@ -153,19 +114,15 @@ func (rsp *SessionProvider) Get(ctx *gctx.Context) (session.Session, error) { return res, nil } -// NewSessionProvider 长短 token + session 机制。短 token 的过期时间是一小时 -// 长 token 的过期时间是 30 天 -func NewSessionProvider(client redis.Cmdable, jwtKey string) *SessionProvider { +// NewSessionProvider 用于管理 Session +func NewSessionProvider(client redis.Cmdable, jwtKey string, + expiration time.Duration) *SessionProvider { // 长 token 过期时间,被看做是 Session 的过期时间 - expiration := time.Hour * 24 * 30 - m := ijwt.NewManagement[session.Claims](ijwt.NewOptions(time.Hour, jwtKey), - ijwt.WithRefreshJWTOptions[session.Claims](ijwt.NewOptions(expiration, jwtKey))) + m := ijwt.NewManagement[session.Claims](ijwt.NewOptions(expiration, jwtKey)) return &SessionProvider{ - client: client, - atHeader: "X-Access-Token", - rtHeader: "X-Refresh-Token", - tokenHeader: "Authorization", - m: m, - expiration: expiration, + client: client, + TokenCarrier: header.NewTokenCarrier(), + m: m, + expiration: expiration, } } diff --git a/session/redis/provider_test.go b/session/redis/provider_test.go index 5afb89e..0583e19 100644 --- a/session/redis/provider_test.go +++ b/session/redis/provider_test.go @@ -17,8 +17,6 @@ package redis import ( - "context" - "net/http" "net/http/httptest" "testing" "time" @@ -40,30 +38,6 @@ type ProviderTestSuite struct { e2e.BaseSuite } -func (s *ProviderTestSuite) TestRenewSession() { - sp := NewSessionProvider(s.RDB, "session") - req, err := http.NewRequest(http.MethodGet, "localhost:8080/hello", nil) - require.NoError(s.T(), err) - writer := httptest.NewRecorder() - gxCtx := &gctx.Context{ - Context: &gin.Context{ - Request: req, - Writer: &e2e.GinResponseWriter{ResponseWriter: writer}, - }, - } - sess, err := sp.NewSession(gxCtx, 123, map[string]string{"jwtKey1": "jwtVal1"}, map[string]any{"sessKe1": "sessVal1"}) - require.NoError(s.T(), err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - err = sess.Set(ctx, "sessKey2", "sessVal2") - require.NoError(s.T(), err) - // 先把 refresh token 取出来,放过去 req 的 header,从而模拟 renew 的请求 - rt := writer.Header().Get("X-Refresh-Token") - req.Header.Set("Authorization", "Bearer "+rt) - err = sp.RenewAccessToken(gxCtx) - require.NoError(s.T(), err) -} - func TestSessionProvider_UpdateClaims(t *testing.T) { testCases := []struct { name string @@ -89,7 +63,7 @@ func TestSessionProvider_UpdateClaims(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() client := tc.mock(ctrl) - sp := NewSessionProvider(client, "123") + sp := NewSessionProvider(client, "123", time.Minute) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) @@ -102,9 +76,10 @@ func TestSessionProvider_UpdateClaims(t *testing.T) { Context: ctx, } newCl := session.Claims{ - Uid: 234, - SSID: "ssid_123", - Data: map[string]string{"hello": "nihao"}} + Uid: 234, + SSID: "ssid_123", + Expiration: 123, + Data: map[string]string{"hello": "nihao"}} err = sp.UpdateClaims(gtx, newCl) assert.Equal(t, tc.wantErr, err) @@ -116,11 +91,6 @@ func TestSessionProvider_UpdateClaims(t *testing.T) { require.NoError(t, err) cl := rc.Data assert.Equal(t, newCl, cl) - token = ctx.Writer.Header().Get("X-Refresh-Token") - rc, err = sp.m.VerifyAccessToken(token) - require.NoError(t, err) - cl = rc.Data - assert.Equal(t, newCl, cl) }) } } @@ -157,7 +127,7 @@ func TestSessionProvider_NewSession(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() client := tc.mock(ctrl) - sp := NewSessionProvider(client, tc.key) + sp := NewSessionProvider(client, tc.key, time.Minute) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) sess, err := sp.NewSession(&gctx.Context{ @@ -172,6 +142,8 @@ func TestSessionProvider_NewSession(t *testing.T) { cl := rs.Claims() assert.True(t, len(cl.SSID) > 0) cl.SSID = "" + assert.Greater(t, cl.Expiration, int64(0)) + cl.Expiration = 0 assert.Equal(t, session.Claims{ Uid: 123, Data: map[string]string{"hello": "world"}, diff --git a/session/redis/session.go b/session/redis/session.go index 4d9ae66..327269e 100644 --- a/session/redis/session.go +++ b/session/redis/session.go @@ -28,7 +28,7 @@ var _ session.Session = &Session{} // Session 生命周期应该和 http 请求保持一致 type Session struct { client redis.Cmdable - // key 是 ssid 拼接而成。注意,它不是 access token,也不是 refresh token + // key 是 ssid 拼接而成。注意,它不是 access token key string claims session.Claims expiration time.Duration diff --git a/session/types.go b/session/types.go index ebf9ec4..f734683 100644 --- a/session/types.go +++ b/session/types.go @@ -51,13 +51,16 @@ type Provider interface { // 也就是,用户可以预期拿到的 Session 永远是没有过期,直接可用的 Get(ctx *gctx.Context) (Session, error) + // Destroy 销毁 Session,一般用在退出登录这种地方 + Destroy(ctx *gctx.Context) error + // UpdateClaims 修改 claims 的数据 // 但是因为 jwt 本身是不可变的,所以实际上这里是重新生成了一个 jwt 的 token // 必须传入正确的 SSID UpdateClaims(ctx *gctx.Context, claims Claims) error // RenewAccessToken 刷新并且返回一个新的 access token - // 这个过程会校验长 token 的合法性 + // 注意,必须是之前的 AccessToken 快要过期但是还没过期的时候 RenewAccessToken(ctx *gctx.Context) error } @@ -65,6 +68,8 @@ type Claims struct { Uid int64 SSID string Data map[string]string + // 过期时间。毫秒数 + Expiration int64 } func (c Claims) Get(key string) ekit.AnyValue { @@ -74,3 +79,10 @@ func (c Claims) Get(key string) ekit.AnyValue { } return ekit.AnyValue{Val: val} } + +// TokenCarrier 用于决定是使用 Header 还是使用 Cookie 来作为 +type TokenCarrier interface { + Inject(ctx *gctx.Context, value string) + Extract(ctx *gctx.Context) string + Clear(ctx *gctx.Context) +}