Skip to content

Commit

Permalink
♻️ refactor: Refactor user token
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Nov 30, 2024
1 parent cacfc1b commit efe69b7
Show file tree
Hide file tree
Showing 13 changed files with 349 additions and 45 deletions.
2 changes: 2 additions & 0 deletions common/config/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ var SystemName = "One Hub"
var ServerAddress = "http://localhost:3000"
var Debug = false

var OldTokenMaxId = 0

var Language = ""
var Footer = ""
var Logo = ""
Expand Down
18 changes: 18 additions & 0 deletions common/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func InitRedisClient() (err error) {
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
return
}

opt.DB = viper.GetInt("redis_db")
RDB = redis.NewClient(opt)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expand Down Expand Up @@ -88,3 +90,19 @@ func GetRedisClient() *redis.Client {
func ScriptRunCtx(ctx context.Context, script *redis.Script, keys []string, args ...interface{}) (interface{}, error) {
return script.Run(ctx, RDB, keys, args...).Result()
}

func RedisExists(key string) (bool, error) {
ctx := context.Background()
exists, err := RDB.Exists(ctx, key).Result()
return exists > 0, err
}

func RedisSAdd(key string, members ...interface{}) error {
ctx := context.Background()
return RDB.SAdd(ctx, key, members...).Err()
}

func RedisSIsMember(key string, member interface{}) (bool, error) {
ctx := context.Background()
return RDB.SIsMember(ctx, key, member).Result()
}
98 changes: 98 additions & 0 deletions common/user-token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package common

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"hash"
"sync"

"github.com/spf13/viper"
"github.com/sqids/sqids-go"
)

var (
hashidsMinLength = 15
hashids *sqids.Sqids

jwtSecretBytes = []byte{}
hmacPool = sync.Pool{
New: func() interface{} {
return hmac.New(sha256.New, jwtSecretBytes)
},
}
)

func InitUserToken() error {
tokenSecret := viper.GetString("user_token_secret")
hashidsSalt := viper.GetString("hashids_salt")

if tokenSecret == "" || hashidsSalt == "" {
return errors.New("token_secret or hashids_salt is not set")
}

var err error
hashids, err = sqids.New(sqids.Options{
MinLength: uint8(hashidsMinLength),
Alphabet: hashidsSalt,
})

jwtSecretBytes = []byte(tokenSecret)

return err
}

func GenerateToken(tokenID, userID int) (string, error) {
payload, err := hashids.Encode([]uint64{uint64(tokenID), uint64(userID)})
if err != nil {
return "", err
}

h := hmacPool.Get().(hash.Hash)
defer func() {
h.Reset()
hmacPool.Put(h)
}()

h.Write([]byte(payload))
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

return payload + "_" + signature, nil
}

func ValidateToken(token string) (tokenID, userID int, err error) {
parts := bytes.SplitN([]byte(token), []byte("_"), 2)
if len(parts) != 2 {
return 0, 0, fmt.Errorf("无效的令牌")
}

payloadEncoded, receivedSignature := parts[0], parts[1]

h := hmacPool.Get().(hash.Hash)
defer func() {
h.Reset()
hmacPool.Put(h)
}()

h.Write(payloadEncoded)
expectedSignature := h.Sum(nil)

decodedSignature, err := base64.RawURLEncoding.DecodeString(string(receivedSignature))
if err != nil {
return 0, 0, fmt.Errorf("签名解码失败")
}

if !bytes.Equal(decodedSignature, expectedSignature) {
return 0, 0, fmt.Errorf("签名验证失败")
}

numbers := hashids.Decode(string(payloadEncoded))
if len(numbers) != 2 {
return 0, 0, fmt.Errorf("无效的令牌")
}

return int(numbers[0]), int(numbers[1]), nil
}
5 changes: 5 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ sql_dsn: "" # 设置之后将使用指定数据库而非 SQLite,请使用 MySQ
sqlite_path: "one-api.db" # sqlite 数据库文件路径
sqlite_busy_timeout: 3000 # sqlite 数据库繁忙超时时间,单位为毫秒,默认为 3000。
redis_conn_string: "" # 设置之后将使用指定 Redis 作为缓存,格式为 "redis://default:redispw@localhost:49153",未设置则不使用 Redis。
redis_db: 0 # redis 数据库,未设置则不使用 Redis。

memory_cache_enabled: false # 是否启用内存缓存,启用后将缓存部分数据,减少数据库查询次数。
sync_frequency: 600 # 在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 600 秒
Expand All @@ -29,6 +30,10 @@ batch_update_interval: 5 # 批量更新聚合的时间间隔,单位为秒,
batch_update_enabled: false # 启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 true 和 false,未设置则默认为 false
auto_price_updates: true # 启用自动更新价格,可选值为 true 和 false,默认为 true

# 令牌设置
user_token_secret: "" # 用户令牌密钥, 请设置32位随机字符串,丢失后用户令牌将无法验证,例如:vWVmFxp5YIOXuHhEod8jBcqiw0zKP2fk
hashids_salt: "" # hashids 盐, 丢失后用户令牌将无法验证,例如:QqGmDAFIdHzc7YBLMUPkEt6XuKjNnarwglZThvOW3y85Jo1bVxC02ipRs94fSe

# 全局设置
global:
api_rate_limit: 180 # 全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 180。
Expand Down
12 changes: 6 additions & 6 deletions controller/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func GetPlaygroundToken(c *gin.Context) {
token, err := model.GetTokenByName(tokenName, userId)
if err != nil {
cleanToken := model.Token{
UserId: userId,
Name: tokenName,
Key: utils.GenerateKey(),
UserId: userId,
Name: tokenName,
// Key: utils.GenerateKey(),
CreatedTime: utils.GetTimestamp(),
AccessedTime: utils.GetTimestamp(),
ExpiredTime: 0,
Expand Down Expand Up @@ -117,9 +117,9 @@ func AddToken(c *gin.Context) {
}

cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: utils.GenerateKey(),
UserId: c.GetInt("id"),
Name: token.Name,
// Key: utils.GenerateKey(),
CreatedTime: utils.GetTimestamp(),
AccessedTime: utils.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ require (
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/sqids/sqids-go v0.4.1 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
Expand Down
9 changes: 9 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func main() {

logger.SetupLogger()
logger.SysLog("One Hub " + config.Version + " started")

// Initialize user token
err := common.InitUserToken()
if err != nil {
logger.FatalLog("failed to initialize user token: " + err.Error())
}

// Initialize SQL Database
model.SetupDB()
defer model.CloseDB()
Expand All @@ -55,6 +62,8 @@ func main() {
// Initialize oidc
oidc.InitOIDCConfig()
relay_util.NewPricing()
model.HandleOldTokenMaxId()

initMemoryCache()
initSync()

Expand Down
12 changes: 2 additions & 10 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,14 @@ func tokenAuth(c *gin.Context, key string) {
return
}

parts := strings.Split(key, "-")
parts := strings.Split(key, "#")
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}

c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
Expand Down
46 changes: 46 additions & 0 deletions model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ var (
UserEnabledCacheKey = "user_enabled:%d"
UserRealtimeQuotaKey = "user_realtime_quota:%d"
UserRealtimeQuotaExpiration = 24 * time.Hour

OldUserTokensCacheKey = "old_user_tokens_cache"
)

func CacheGetTokenByKey(key string) (*Token, error) {
Expand Down Expand Up @@ -187,3 +189,47 @@ func CacheUpdateUserRealtimeQuota(id int, quota int) (int64, error) {

return newValue, nil
}

func HandleOldTokenMaxId() {
if config.OldTokenMaxId == 0 || !config.RedisEnabled {
return
}

// 检测OldUserTokensCacheKey是否存在
exists, _ := redis.RedisExists(OldUserTokensCacheKey)
if exists {
return
}
const batchSize = 1000
var offset int

for {
var tokenKeys []interface{}
result := DB.Model(&Token{}).
Where("id <= ?", config.OldTokenMaxId).
Limit(batchSize).
Offset(offset).
Pluck("key", &tokenKeys)

if result.Error != nil {
logger.SysError("查询旧token失败: " + result.Error.Error())
return
}

if len(tokenKeys) == 0 {
if offset == 0 {
logger.SysLog("没有找到旧token")
}
break
}

if err := redis.RedisSAdd(OldUserTokensCacheKey, tokenKeys...); err != nil {
logger.SysError("添加旧token到Redis失败: " + err.Error())
}

logger.SysLog(fmt.Sprintf("已处理 %d 个旧token", offset+len(tokenKeys)))
offset += batchSize

time.Sleep(100 * time.Millisecond)
}
}
Loading

0 comments on commit efe69b7

Please sign in to comment.