| 
 | 1 | +// Copyright (c) 2022 Gitpod GmbH. All rights reserved.  | 
 | 2 | +// Licensed under the GNU Affero General Public License (AGPL).  | 
 | 3 | +// See License-AGPL.txt in the project root for license information.  | 
 | 4 | + | 
 | 5 | +package db  | 
 | 6 | + | 
 | 7 | +import (  | 
 | 8 | +	"context"  | 
 | 9 | +	"database/sql/driver"  | 
 | 10 | +	"errors"  | 
 | 11 | +	"fmt"  | 
 | 12 | +	"strings"  | 
 | 13 | +	"time"  | 
 | 14 | + | 
 | 15 | +	"github.com/google/uuid"  | 
 | 16 | +	"gorm.io/gorm"  | 
 | 17 | +)  | 
 | 18 | + | 
 | 19 | +type PersonalAccessToken struct {  | 
 | 20 | +	ID             uuid.UUID `gorm:"primary_key;column:id;type:varchar;size:255;" json:"id"`  | 
 | 21 | +	UserID         uuid.UUID `gorm:"column:userId;type:varchar;size:255;" json:"userId"`  | 
 | 22 | +	Hash           string    `gorm:"column:hash;type:varchar;size:255;" json:"hash"`  | 
 | 23 | +	Name           string    `gorm:"column:name;type:varchar;size:255;" json:"name"`  | 
 | 24 | +	Description    string    `gorm:"column:description;type:varchar;size:255;" json:"description"`  | 
 | 25 | +	Scopes         Scopes    `gorm:"column:scopes;type:text;size:65535;" json:"scopes"`  | 
 | 26 | +	ExpirationTime time.Time `gorm:"column:expirationTime;type:timestamp;" json:"expirationTime"`  | 
 | 27 | +	CreatedAt      time.Time `gorm:"column:createdAt;type:timestamp;default:CURRENT_TIMESTAMP(6);" json:"createdAt"`  | 
 | 28 | +	LastModified   time.Time `gorm:"column:_lastModified;type:timestamp;default:CURRENT_TIMESTAMP(6);" json:"_lastModified"`  | 
 | 29 | + | 
 | 30 | +	// deleted is reserved for use by db-sync.  | 
 | 31 | +	_ bool `gorm:"column:deleted;type:tinyint;default:0;" json:"deleted"`  | 
 | 32 | +}  | 
 | 33 | + | 
 | 34 | +type Scopes []string  | 
 | 35 | + | 
 | 36 | +// TableName sets the insert table name for this struct type  | 
 | 37 | +func (d *PersonalAccessToken) TableName() string {  | 
 | 38 | +	return "d_b_personal_access_token"  | 
 | 39 | +}  | 
 | 40 | + | 
 | 41 | +func GetToken(ctx context.Context, conn *gorm.DB, id uuid.UUID) (PersonalAccessToken, error) {  | 
 | 42 | +	var token PersonalAccessToken  | 
 | 43 | + | 
 | 44 | +	db := conn.WithContext(ctx)  | 
 | 45 | + | 
 | 46 | +	db = db.Where("id = ?", id).First(&token)  | 
 | 47 | +	if db.Error != nil {  | 
 | 48 | +		return PersonalAccessToken{}, fmt.Errorf("Failed to retrieve token: %w", db.Error)  | 
 | 49 | +	}  | 
 | 50 | + | 
 | 51 | +	return token, nil  | 
 | 52 | +}  | 
 | 53 | + | 
 | 54 | +func CreateToken(ctx context.Context, conn *gorm.DB, req PersonalAccessToken) (PersonalAccessToken, error) {  | 
 | 55 | +	if req.UserID == uuid.Nil {  | 
 | 56 | +		return PersonalAccessToken{}, fmt.Errorf("Invalid or empty userID")  | 
 | 57 | +	}  | 
 | 58 | +	if req.Hash == "" {  | 
 | 59 | +		return PersonalAccessToken{}, fmt.Errorf("Token hash required")  | 
 | 60 | +	}  | 
 | 61 | +	if req.Name == "" {  | 
 | 62 | +		return PersonalAccessToken{}, fmt.Errorf("Token name required")  | 
 | 63 | +	}  | 
 | 64 | +	if req.ExpirationTime.IsZero() {  | 
 | 65 | +		return PersonalAccessToken{}, fmt.Errorf("Expiration time required")  | 
 | 66 | +	}  | 
 | 67 | + | 
 | 68 | +	token := PersonalAccessToken{  | 
 | 69 | +		ID:             req.ID,  | 
 | 70 | +		UserID:         req.UserID,  | 
 | 71 | +		Hash:           req.Hash,  | 
 | 72 | +		Name:           req.Name,  | 
 | 73 | +		Description:    req.Description,  | 
 | 74 | +		Scopes:         req.Scopes,  | 
 | 75 | +		ExpirationTime: req.ExpirationTime,  | 
 | 76 | +		CreatedAt:      time.Now().UTC(),  | 
 | 77 | +		LastModified:   time.Now().UTC(),  | 
 | 78 | +	}  | 
 | 79 | + | 
 | 80 | +	db := conn.WithContext(ctx).Create(req)  | 
 | 81 | +	if db.Error != nil {  | 
 | 82 | +		return PersonalAccessToken{}, fmt.Errorf("Failed to create token for user %s", req.UserID)  | 
 | 83 | +	}  | 
 | 84 | + | 
 | 85 | +	return token, nil  | 
 | 86 | +}  | 
 | 87 | + | 
 | 88 | +// Scan() and Value() allow having a list of strings as a type for Scopes  | 
 | 89 | +func (s *Scopes) Scan(src any) error {  | 
 | 90 | +	bytes, ok := src.([]byte)  | 
 | 91 | +	if !ok {  | 
 | 92 | +		return errors.New("src value cannot cast to []byte")  | 
 | 93 | +	}  | 
 | 94 | +	*s = strings.Split(string(bytes), ",")  | 
 | 95 | +	return nil  | 
 | 96 | +}  | 
 | 97 | +func (s Scopes) Value() (driver.Value, error) {  | 
 | 98 | +	if len(s) == 0 {  | 
 | 99 | +		return "", nil  | 
 | 100 | +	}  | 
 | 101 | +	return strings.Join(s, ","), nil  | 
 | 102 | +}  | 
0 commit comments