Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #305 from neicnordic/refactor/middleware-and-cache
Browse files Browse the repository at this point in the history
refactor middleware and cache in an attempt of clarifying its operation
  • Loading branch information
teemukataja authored Sep 4, 2023
2 parents 7babab0 + 208114a commit 94ab3e7
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 111 deletions.
45 changes: 20 additions & 25 deletions api/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import (
log "github.com/sirupsen/logrus"
)

var datasetsKey = "datasets"
// requestContextKey holds a name for the request context storage key
// which is used to store and get the permissions after passing middleware
const requestContextKey = "requestContextKey"

// TokenMiddleware performs access token verification and validation
// JWTs are verified and validated by the app, opaque tokens are sent to AAI for verification
// Successful auth results in list of authorised datasets
// Successful auth results in list of authorised datasets.
// The datasets are stored into a session cache for subsequent requests, and also
// to the current request context for use in the endpoints.
func TokenMiddleware() gin.HandlerFunc {

return func(c *gin.Context) {
Expand All @@ -23,11 +27,11 @@ func TokenMiddleware() gin.HandlerFunc {
if err != nil {
log.Debugf("no session cookie received")
}
var datasetCache session.DatasetCache
var cache session.Cache
var exists bool
if sessionCookie != "" {
log.Debug("session cookie received")
datasetCache, exists = session.Get(sessionCookie)
cache, exists = session.Get(sessionCookie)
}

if !exists {
Expand Down Expand Up @@ -57,14 +61,11 @@ func TokenMiddleware() gin.HandlerFunc {
// 200 OK with [] empty dataset list, when listing datasets (use case for sda-filesystem download tool)
// 404 dataset not found, when listing files from a dataset
// 401 unauthorised, when downloading a file
datasets := auth.GetPermissions(*visas)
datasetCache = session.DatasetCache{
Datasets: datasets,
}
cache.Datasets = auth.GetPermissions(*visas)

// Start a new session and store datasets under the session key
key := session.NewSessionKey()
session.Set(key, datasetCache)
session.Set(key, cache)
c.SetCookie(config.Config.Session.Name, // name
key, // value
int(config.Config.Session.Expiration)/1e9, // max age
Expand All @@ -77,31 +78,25 @@ func TokenMiddleware() gin.HandlerFunc {
}

// Store dataset list to request context, for use in the endpoint handlers
c = storeDatasets(c, datasetCache)
log.Debugf("storing %v to request context", cache)
c.Set(requestContextKey, cache)

// Forward request to the next endpoint handler
c.Next()
}

}

// storeDatasets stores the dataset list to the request context
func storeDatasets(c *gin.Context, datasets session.DatasetCache) *gin.Context {
log.Debugf("storing %v datasets to request context", datasets)

c.Set(datasetsKey, datasets)

return c
}

// GetDatasets extracts the dataset list from the request context
var GetDatasets = func(c *gin.Context) session.DatasetCache {
var datasetCache session.DatasetCache
cached, exists := c.Get(datasetsKey)
// GetCacheFromContext is a helper function that endpoints can use to get data
// stored to the *current* request context (not the session storage).
// The request context was populated by the middleware, which in turn uses the session storage.
var GetCacheFromContext = func(c *gin.Context) session.Cache {
var cache session.Cache
cached, exists := c.Get(requestContextKey)
if exists {
datasetCache = cached.(session.DatasetCache)
cache = cached.(session.Cache)
}
log.Debugf("returning %v from request context", cached)

return datasetCache
return cache
}
24 changes: 12 additions & 12 deletions api/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ func TestTokenMiddleware_Success_NoCache(t *testing.T) {
// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(c *gin.Context) {
datasets, _ := c.Get(datasetsKey)
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
datasets, _ := c.Get(requestContextKey)
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", datasets, expectedDatasets)
}
}
Expand Down Expand Up @@ -224,9 +224,9 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
originalGetCache := session.Get

// Substitute mock functions
session.Get = func(key string) (session.DatasetCache, bool) {
session.Get = func(key string) (session.Cache, bool) {
log.Warningf("session.Get %v", key)
cached := session.DatasetCache{
cached := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}

Expand All @@ -248,8 +248,8 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(c *gin.Context) {
datasets, _ := c.Get(datasetsKey)
if !reflect.DeepEqual(datasets.(session.DatasetCache).Datasets, expectedDatasets) {
datasets, _ := c.Get(requestContextKey)
if !reflect.DeepEqual(datasets.(session.Cache).Datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %s expected %s", datasets, expectedDatasets)
}
}
Expand Down Expand Up @@ -284,13 +284,13 @@ func TestStoreDatasets(t *testing.T) {
c, _ := gin.CreateTestContext(w)

// Store data to request context
datasets := session.DatasetCache{
datasets := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}
modifiedContext := storeDatasets(c, datasets)
c.Set(requestContextKey, datasets)

// Verify that context has new data
storedDatasets := modifiedContext.Value(datasetsKey).(session.DatasetCache)
storedDatasets := c.Value(requestContextKey).(session.Cache)
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}
Expand All @@ -304,13 +304,13 @@ func TestGetDatasets(t *testing.T) {
c, _ := gin.CreateTestContext(w)

// Store data to request context
datasets := session.DatasetCache{
datasets := session.Cache{
Datasets: []string{"dataset1", "dataset2"},
}
modifiedContext := storeDatasets(c, datasets)
c.Set(requestContextKey, datasets)

// Verify that context has new data
storedDatasets := GetDatasets(modifiedContext)
storedDatasets := GetCacheFromContext(c)
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}
Expand Down
12 changes: 6 additions & 6 deletions api/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func ListBuckets(c *gin.Context) {
}

buckets := []Bucket{}
datasetCache := middleware.GetDatasets(c)
for _, dataset := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, dataset := range cache.Datasets {
datasetInfo, err := database.GetDatasetInfo(dataset)
if err != nil {
log.Errorf("Failed to get dataset information: %v", err)
Expand Down Expand Up @@ -123,8 +123,8 @@ func ListObjects(c *gin.Context) {
dataset := c.Param("dataset")

allowed := false
datasetCache := middleware.GetDatasets(c)
for _, known := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, known := range cache.Datasets {
if dataset == known {
allowed = true

Expand Down Expand Up @@ -244,8 +244,8 @@ func parseParams(c *gin.Context) *gin.Context {
path = string(protocolPattern.ReplaceAll([]byte(path), []byte("$1/$2")))
}

datasetCache := middleware.GetDatasets(c)
for _, dataset := range datasetCache.Datasets {
cache := middleware.GetCacheFromContext(c)
for _, dataset := range cache.Datasets {
// check that the path starts with the dataset name, but also that the
// path is only the dataset, or that the following character is a slash.
// This prevents wrong matches in cases like when one dataset name is a
Expand Down
14 changes: 7 additions & 7 deletions api/sda/sda.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func Datasets(c *gin.Context) {

// Retrieve dataset list from request context
// generated by the authentication middleware
datasetCache := middleware.GetDatasets(c)
cache := middleware.GetCacheFromContext(c)

// Return response
c.JSON(http.StatusOK, datasetCache.Datasets)
c.JSON(http.StatusOK, cache.Datasets)
}

// find looks for a dataset name in a list of datasets
Expand All @@ -60,11 +60,11 @@ var getFiles = func(datasetID string, ctx *gin.Context) ([]*database.FileInfo, i

// Retrieve dataset list from request context
// generated by the authentication middleware
datasetCache := middleware.GetDatasets(ctx)
cache := middleware.GetCacheFromContext(ctx)

log.Debugf("request to process files for dataset %s", sanitizeString(datasetID))

if find(datasetID, datasetCache.Datasets) {
if find(datasetID, cache.Datasets) {
// Get file metadata
files, err := database.GetFiles(datasetID)
if err != nil {
Expand Down Expand Up @@ -133,12 +133,12 @@ func Download(c *gin.Context) {
}

// Get datasets from request context, parsed previously by token middleware
datasetCache := middleware.GetDatasets(c)
cache := middleware.GetCacheFromContext(c)

// Verify user has permission to datafile
permission := false
for d := range datasetCache.Datasets {
if datasetCache.Datasets[d] == dataset {
for d := range cache.Datasets {
if cache.Datasets[d] == dataset {
permission = true

break
Expand Down
Loading

0 comments on commit 94ab3e7

Please sign in to comment.