diff --git a/README.md b/README.md index 15978c9..eaa50a3 100644 --- a/README.md +++ b/README.md @@ -11,3 +11,7 @@ Useful commands: - To lint: `golangci-lint run -v` - To reformat long lines: `golines . -w --max-len=120 --base-formatter=gofmt` or to target a specific file replace `.` with the filename in the command + +## Important references + +[GitLab refresh tokens do not expire](https://gitlab.com/gitlab-org/gitlab/-/issues/340848#note_953496566) diff --git a/go.mod b/go.mod index 922406c..afa7ef9 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,16 @@ module github.com/SwissDataScienceCenter/renku-gateway-v2 go 1.19 require ( + github.com/go-co-op/gocron v1.18.0 github.com/go-redis/redis/v9 v9.0.0-rc.2 + github.com/go-redis/redismock/v9 v9.0.0-rc.2 github.com/oklog/ulid/v2 v2.1.0 - golang.org/x/net v0.2.0 + golang.org/x/net v0.5.0 ) require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/go-redis/redismock/v9 v9.0.0-rc.2 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect ) diff --git a/go.sum b/go.sum index f0b7c1d..488ae55 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/go-co-op/gocron v1.18.0 h1:SxTyJ5xnSN4byCq7b10LmmszFdxQlSQJod8s3gbnXxA= +github.com/go-co-op/gocron v1.18.0/go.mod h1:sD/a0Aadtw5CpflUJ/lpP9Vfdk979Wl1Sg33HPHg0FY= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-redis/redis/v9 v9.0.0-rc.2 h1:IN1eI8AvJJeWHjMW/hlFAv2sAfvTun2DVksDDJ3a6a0= github.com/go-redis/redis/v9 v9.0.0-rc.2/go.mod h1:cgBknjwcBJa2prbnuHH/4k/Mlj4r0pWNV2HBanHujfY= @@ -65,6 +67,8 @@ github.com/onsi/gomega v1.24.1/go.mod h1:3AOiACssS3/MajrniINInwbfOOtfZvplPzuRSmv github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -97,12 +101,14 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -124,8 +130,8 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -134,8 +140,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= diff --git a/internal/adapters/tokenrefresher/tokenrefresher.go b/internal/adapters/tokenrefresher/tokenrefresher.go new file mode 100644 index 0000000..2362bcb --- /dev/null +++ b/internal/adapters/tokenrefresher/tokenrefresher.go @@ -0,0 +1,140 @@ +// Package tokenrefresher refreshes oauth tokens stored by the gateway. +package tokenrefresher + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "time" + + "github.com/SwissDataScienceCenter/renku-gateway-v2/internal/models" + "github.com/go-co-op/gocron" +) + +// tokenReponse struct required to unmarshal the response from a POST token refresh request +type tokenResponse struct { + AccessToken string `json:"access_token"` + Type string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresIn int64 `json:"refresh_expires_in"` + Scope string `json:"scope"` + CreatedAt int64 `json:"created_at"` +} + +func (t tokenResponse) String() string { + return fmt.Sprintf("CreatedAt: %v, Type: %v, ExpiresIn: %v, RefreshTokenExpiresIn: %v", t.CreatedAt, t.Type, t.ExpiresIn, t.RefreshTokenExpiresIn) +} + +// RefresherTokenStore is an interface used for refreshing tokens stored by the gateway +type RefresherTokenStore interface { + GetRefreshToken(context.Context, string) (models.RefreshToken, error) + GetAccessToken(context.Context, string) (models.AccessToken, error) + SetRefreshToken(context.Context, models.RefreshToken) error + SetAccessToken(context.Context, models.AccessToken) error + GetExpiringAccessTokenIDs(context.Context, time.Time, time.Time) ([]string, error) +} + +// ScheduleRefreshExpiringTokens intialises a gocron job to run refreshExpiringTokens at a specified interval +func ScheduleRefreshExpiringTokens(ctx context.Context, tokenStore RefresherTokenStore, gitlabClientID string, gitlabClientSecret string, minsToExpiration int) error { + s := gocron.NewScheduler(time.UTC) + job, err := s.Every(minsToExpiration).Minutes().Do(refreshExpiringTokens, ctx, tokenStore, gitlabClientID, gitlabClientSecret, minsToExpiration) + s.StartBlocking() + if err != nil { + log.Printf("Starting gocron job failed: %s\n", err) + } else { + log.Printf("Job starting: %v\n", job) + } + return err +} + +// refreshExpiringTokens refreshes tokens in the token store expiring in the next minsToExpiration minutes +func refreshExpiringTokens(ctx context.Context, tokenStore RefresherTokenStore, clientID string, clientSecret string, minsToExpiration int) error { + // Get a list of expiring access tokens ids in the next minsToExpiration minutes + expiringTokenIDs, err := tokenStore.GetExpiringAccessTokenIDs(ctx, time.Now(), time.Now().Add(time.Minute*time.Duration(minsToExpiration))) + if err != nil { + log.Printf("GetExpiringAccessTokenIDs failed: %s\n", err) + return err + } + + // For each token id expiring in the next minsToExpiration minutes + for _, expiringTokenID := range expiringTokenIDs { + + // Get the refresh and access tokens associated with the token ID + myRefreshToken, err := tokenStore.GetRefreshToken(ctx, expiringTokenID) + if err != nil { + log.Printf("GetRefreshToken failed: %s\n", err) + return err + } + + myAccessToken, err := tokenStore.GetAccessToken(ctx, expiringTokenID) + if err != nil { + log.Printf("GetAccessToken failed: %s\n", err) + return err + } + + // Set the parameters required to refresh the tokens + params := url.Values{} + params.Add("client_id", clientID) + params.Add("client_secret", clientSecret) + params.Add("refresh_token", myRefreshToken.Value) + params.Add("grant_type", "refresh_token") + + // Send the POST request to refresh the tokens + resp, err := http.PostForm(myAccessToken.URL, params) + if err != nil { + log.Printf("Request Failed: %s\n", err) + return err + } + defer resp.Body.Close() + + // Decode JSON returned from the POST refresh request into a tokenResponse + token := tokenResponse{} + err = json.NewDecoder(resp.Body).Decode(&token) + if err != nil { + log.Printf("Decoding body failed: %s\n", err) + return err + } + + log.Printf("New token received: %v\n", token) + + // Calculate the UNIX timestamp at which the newly refreshed access and refresh tokens will expire + accessTokenExpiration := time.Unix(token.CreatedAt+token.ExpiresIn, 0) + // Keycloak does not provide a created_at parameter. + // Therefore, if the value of token.CreatedAt is 0, + // we replace token.CreatedAt with time.Now() + if token.CreatedAt == 0 { + accessTokenExpiration = time.Now().Add(time.Second * time.Duration(token.ExpiresIn)) + } + + refreshTokenExpiration := time.Now().Add(time.Second * time.Duration(token.RefreshTokenExpiresIn)) + // Gitlab refresh tokens do not expire + // (see https://gitlab.com/gitlab-org/gitlab/-/issues/340848#note_953496566). + // Therefore, in the case that there is no refresh token expiration time, + // we set a refresh token expiration time of 0. + if token.RefreshTokenExpiresIn == 0 { + refreshTokenExpiration = time.Unix(0, 0) + } + + // Set the refreshed access and refresh token values into the token store + err = tokenStore.SetAccessToken(ctx, models.AccessToken{ + ID: myAccessToken.ID, + Value: token.AccessToken, + ExpiresAt: accessTokenExpiration, + URL: myAccessToken.URL, + Type: myAccessToken.Type, + }) + + err = tokenStore.SetRefreshToken(ctx, models.RefreshToken{ + ID: myRefreshToken.ID, + Value: token.RefreshToken, + ExpiresAt: refreshTokenExpiration, + }) + } + + log.Printf("%v expiring access tokens refreshed, evaluating again in %v minutes\n", len(expiringTokenIDs), minsToExpiration) + return err +} diff --git a/internal/adapters/tokenrefresher/tokenrefresher_test.go b/internal/adapters/tokenrefresher/tokenrefresher_test.go new file mode 100644 index 0000000..9d1df7a --- /dev/null +++ b/internal/adapters/tokenrefresher/tokenrefresher_test.go @@ -0,0 +1,332 @@ +package tokenrefresher + +import ( + "context" + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/SwissDataScienceCenter/renku-gateway-v2/internal/models" +) + +var ctx = context.Background() + +type DummyAdapter struct { + err error + accessToken models.AccessToken + refreshToken models.RefreshToken + tokenID string +} + +func (d *DummyAdapter) GetRefreshToken(context.Context, string) (models.RefreshToken, error) { + return d.refreshToken, d.err +} +func (d *DummyAdapter) GetAccessToken(context.Context, string) (models.AccessToken, error) { + return d.accessToken, d.err +} +func (d *DummyAdapter) SetRefreshToken(ctx context.Context, aRefreshToken models.RefreshToken) error { + d.refreshToken = aRefreshToken + return d.err +} +func (d *DummyAdapter) SetAccessToken(ctx context.Context, anAccessToken models.AccessToken) error { + d.accessToken = anAccessToken + return d.err +} +func (d *DummyAdapter) GetExpiringAccessTokenIDs(context.Context, time.Time, time.Time) ([]string, error) { + return []string{d.tokenID}, d.err +} + +func TestRefreshExpiringTokensGitlab(t *testing.T) { + + log.Printf("Testing GitLab access token refresh") + + // Set dummy values for the 'existing' access and refresh tokens, and the oauth client id and secret + tokenID := "rNDSNs005xrNvrgKZ5vJGCDqwA3VQ1MB" + refreshTokenValue := "QG2RX43C81P5SNS1GACEMNKVT3SDBS" + accessTokenValue := "C1SB4BC3HTP841TGVS4R4G5JEAVQT4W" + clientID := "iPG5UPqrV6LiXiziLbj0CBGbDvWdPWwG" + clientSecret := "9p9KBXSUj037qkR55mdS0yAAecBxbb8Q" + tokenType := "git" + + // Set the dummy values we want the access and refresh tokens to have after refreshing them. + refreshedAccessTokenValue := "6XGQJCST3BY1BZ7X5X78X2MLF0W1AUB5" + refreshedRefreshTokenValue := "5EU358RBY51B88OP0JJ5S15WPSTCSCX3" + refreshedTokenCreationTime := time.Now().Unix() + + // Set up test HTTP server that refresh requests will be sent to + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" { + log.Println(w, "Received POST request!") + err := r.ParseForm() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + t.Fatal(err) + } + + // Ensure the expected values are received by the test HTTP server + if refreshTokenValue == r.PostForm["refresh_token"][0] { + log.Printf("The refresh token posted is the correct value, %v\n", r.PostForm["refresh_token"][0]) + } else { + t.Errorf("The refresh token posted is NOT the correct value, got %v want %v\n", r.PostForm["refresh_token"][0], refreshTokenValue) + } + + if clientID == r.PostForm["client_id"][0] { + log.Printf("The client ID posted is the correct value, %v\n", r.PostForm["client_id"][0]) + } else { + t.Errorf("The client ID posted is NOT the correct value, got %v want %v\n", r.PostForm["client_id"][0], refreshTokenValue) + } + + if clientSecret == r.PostForm["client_secret"][0] { + log.Printf("The client secret posted is the correct value, %v\n", r.PostForm["client_secret"][0]) + } else { + t.Errorf("The client secret posted is NOT the correct value, got %v want %v\n", r.PostForm["client_secret"][0], refreshTokenValue) + } + + if "refresh_token" == r.PostForm["grant_type"][0] { + log.Printf("The grant_type posted is the correct value, %v\n", r.PostForm["grant_type"][0]) + } else { + t.Errorf("The grant_type posted is NOT the correct value, got %v want %v\n", r.PostForm["grant_type"][0], "refresh_token") + } + + // Return the refreshed token values, and the other values Gitlab returns from the test HTTP server + w.Header().Set("Content-Type", "application/json") + + responseData := tokenResponse{ + AccessToken: refreshedAccessTokenValue, + Type: "bearer", + ExpiresIn: 7200, + RefreshToken: refreshedRefreshTokenValue, + Scope: "api", + CreatedAt: refreshedTokenCreationTime, + } + + err = json.NewEncoder(w).Encode(&responseData) + if err != nil { + t.Fatal(err) + } + } + })) + defer srv.Close() + + // Initialise dummy token store + var myRefresherTokenStore RefresherTokenStore = &DummyAdapter{} + + // Create a refresh and access token in our dummy token store with the pre-refresh token values + err := myRefresherTokenStore.SetAccessToken(ctx, models.AccessToken{ + ID: tokenID, + Value: accessTokenValue, + ExpiresAt: time.Now().Add(time.Minute * 5), + URL: srv.URL, + Type: tokenType, + }) + if err != nil { + t.Fatal(err) + } + + err = myRefresherTokenStore.SetRefreshToken(ctx, models.RefreshToken{ + ID: tokenID, + Value: refreshTokenValue, + }) + if err != nil { + t.Fatal(err) + } + + // Refresh tokens expiring in the next 5 minutes + err = refreshExpiringTokens(ctx, myRefresherTokenStore, clientID, clientSecret, 5) + if err != nil { + t.Fatal(err) + } + + // Get the newly refreshed access and refresh tokens and ensure they contain the expected post-refresh values + myNewAccessToken, err := myRefresherTokenStore.GetAccessToken(ctx, tokenID) + if err != nil { + t.Fatal(err) + } + + myNewRefreshToken, err := myRefresherTokenStore.GetRefreshToken(ctx, tokenID) + if err != nil { + t.Fatal(err) + } + + if refreshedAccessTokenValue == myNewAccessToken.Value { + log.Printf("The new access token is the correct value, %v\n", myNewAccessToken.Value) + } else { + t.Errorf("The new access token received is NOT the correct value, got %v want %v\n", myNewAccessToken.Value, refreshedAccessTokenValue) + } + + if srv.URL == myNewAccessToken.URL { + log.Printf("The new access token URL is the correct value, %v\n", myNewAccessToken.URL) + } else { + t.Errorf("The new access token URL received is NOT the correct value, got %v want %v\n", myNewAccessToken.URL, srv.URL) + } + + if tokenType == myNewAccessToken.Type { + log.Printf("The new access token type is the correct value, %v\n", myNewAccessToken.Type) + } else { + t.Errorf("The new access token URL received is NOT the correct value, got %v want %v\n", myNewAccessToken.Type, tokenType) + } + + if refreshedTokenCreationTime+7200 == myNewAccessToken.ExpiresAt.Unix() { + log.Printf("The new access token expiration time is the correct value, %v\n", myNewAccessToken.ExpiresAt.Unix()) + } else { + t.Errorf("The new access token expiration time received is NOT the correct value, got %v want %v\n", myNewAccessToken.ExpiresAt.Unix(), refreshedTokenCreationTime+7200) + } + + if refreshedRefreshTokenValue == myNewRefreshToken.Value { + log.Printf("The new refresh token is the correct value, %v\n", myNewRefreshToken.Value) + } else { + t.Errorf("The new refresh token received is NOT the correct value, got %v want %v\n", myNewRefreshToken.Value, refreshedRefreshTokenValue) + } +} + +func TestRefreshExpiringTokensKeycloak(t *testing.T) { + + log.Printf("Testing Keycloak access token refresh") + + // Set dummy values for the 'existing' access and refresh tokens, and the oauth client id and secret + tokenID := "rNDSNs005xrNvrgKZ5vJGCDqwA3VQ1MB" + refreshTokenValue := "QG2RX43C81P5SNS1GACEMNKVT3SDBS" + accessTokenValue := "C1SB4BC3HTP841TGVS4R4G5JEAVQT4W" + clientID := "iPG5UPqrV6LiXiziLbj0CBGbDvWdPWwG" + clientSecret := "9p9KBXSUj037qkR55mdS0yAAecBxbb8Q" + tokenType := "keycloak" + + // Set the dummy values we want the access and refresh tokens to have after refreshing them. + refreshedAccessTokenValue := "6XGQJCST3BY1BZ7X5X78X2MLF0W1AUB5" + refreshedRefreshTokenValue := "5EU358RBY51B88OP0JJ5S15WPSTCSCX3" + refreshedTokenCreationTime := time.Now().Unix() + + // Set up test HTTP server that refresh requests will be sent to + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" { + log.Println(w, "Received POST request!") + err := r.ParseForm() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + t.Fatal(err) + } + + // Ensure the expected values are received by the test HTTP server + if refreshTokenValue == r.PostForm["refresh_token"][0] { + log.Printf("The refresh token posted is the correct value, %v\n", r.PostForm["refresh_token"][0]) + } else { + t.Errorf("The refresh token posted is NOT the correct value, got %v want %v\n", r.PostForm["refresh_token"][0], refreshTokenValue) + } + + if clientID == r.PostForm["client_id"][0] { + log.Printf("The client ID posted is the correct value, %v\n", r.PostForm["client_id"][0]) + } else { + t.Errorf("The client ID posted is NOT the correct value, got %v want %v\n", r.PostForm["client_id"][0], refreshTokenValue) + } + + if clientSecret == r.PostForm["client_secret"][0] { + log.Printf("The client secret posted is the correct value, %v\n", r.PostForm["client_secret"][0]) + } else { + t.Errorf("The client secret posted is NOT the correct value, got %v want %v\n", r.PostForm["client_secret"][0], refreshTokenValue) + } + + if "refresh_token" == r.PostForm["grant_type"][0] { + log.Printf("The grant_type posted is the correct value, %v\n", r.PostForm["grant_type"][0]) + } else { + t.Errorf("The grant_type posted is NOT the correct value, got %v want %v\n", r.PostForm["grant_type"][0], "refresh_token") + } + + // Return the refreshed token values, and the other values Keycloak returns from the test HTTP server + w.Header().Set("Content-Type", "application/json") + + responseData := tokenResponse{ + AccessToken: refreshedAccessTokenValue, + Type: "bearer", + ExpiresIn: 1800, + RefreshTokenExpiresIn: 86400, + RefreshToken: refreshedRefreshTokenValue, + Scope: "api", + } + + err = json.NewEncoder(w).Encode(&responseData) + if err != nil { + t.Fatal(err) + } + } + })) + defer srv.Close() + + // Initialise dummy token store + var myRefresherTokenStore RefresherTokenStore = &DummyAdapter{} + + // Create a refresh and access token in our dummy token store with the pre-refresh token values + err := myRefresherTokenStore.SetAccessToken(ctx, models.AccessToken{ + ID: tokenID, + Value: accessTokenValue, + ExpiresAt: time.Now().Add(time.Minute * 5), + URL: srv.URL, + Type: tokenType, + }) + if err != nil { + t.Fatal(err) + } + + err = myRefresherTokenStore.SetRefreshToken(ctx, models.RefreshToken{ + ID: tokenID, + Value: refreshTokenValue, + }) + if err != nil { + t.Fatal(err) + } + + // Refresh tokens expiring in the next 5 minutes + err = refreshExpiringTokens(ctx, myRefresherTokenStore, clientID, clientSecret, 5) + if err != nil { + t.Fatal(err) + } + + // Get the newly refreshed access and refresh tokens and ensure they contain the expected post-refresh values + myNewAccessToken, err := myRefresherTokenStore.GetAccessToken(ctx, tokenID) + if err != nil { + t.Fatal(err) + } + + myNewRefreshToken, err := myRefresherTokenStore.GetRefreshToken(ctx, tokenID) + if err != nil { + t.Fatal(err) + } + + if refreshedAccessTokenValue == myNewAccessToken.Value { + log.Printf("The new access token is the correct value, %v\n", myNewAccessToken.Value) + } else { + t.Errorf("The new access token received is NOT the correct value, got %v want %v\n", myNewAccessToken.Value, refreshedAccessTokenValue) + } + + if srv.URL == myNewAccessToken.URL { + log.Printf("The new access token URL is the correct value, %v\n", myNewAccessToken.URL) + } else { + t.Errorf("The new access token URL received is NOT the correct value, got %v want %v\n", myNewAccessToken.URL, srv.URL) + } + + if tokenType == myNewAccessToken.Type { + log.Printf("The new access token type is the correct value, %v\n", myNewAccessToken.Type) + } else { + t.Errorf("The new access token URL received is NOT the correct value, got %v want %v\n", myNewAccessToken.Type, tokenType) + } + + if refreshedTokenCreationTime+1800 == myNewAccessToken.ExpiresAt.Unix() { + log.Printf("The new access token expiration time is the correct value, %v\n", myNewAccessToken.ExpiresAt.Unix()) + } else { + t.Errorf("The new access token expiration time received is NOT the correct value, got %v want %v\n", myNewAccessToken.ExpiresAt.Unix(), refreshedTokenCreationTime+7200) + } + + if refreshedRefreshTokenValue == myNewRefreshToken.Value { + log.Printf("The new refresh token is the correct value, %v\n", myNewRefreshToken.Value) + } else { + t.Errorf("The new refresh token received is NOT the correct value, got %v want %v\n", myNewRefreshToken.Value, refreshedRefreshTokenValue) + } + + if refreshedTokenCreationTime+86400 == myNewRefreshToken.ExpiresAt.Unix() { + log.Printf("The new refresh token expiration time is the correct value, %v\n", myNewRefreshToken.ExpiresAt.Unix()) + } else { + t.Errorf("The new refresh token received is NOT the correct value, got %v want %v\n", myNewRefreshToken.ExpiresAt.Unix(), refreshedTokenCreationTime+86400) + } +}