diff --git a/auth/middleware.go b/auth/middleware.go index 6be4f39..363f2c6 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "strings" "time" @@ -51,33 +52,66 @@ func sessionCallbacks(request *http.Request, writer http.ResponseWriter) *http.R } func reqisterUser(request *http.Request, writer http.ResponseWriter, db *gorm.DB) *http.Request { - if token, err := getToken(request); err == nil { - device := &model.Device{} - if db.Where("token = ?", token).Find(device).RecordNotFound() { - log.Info().Str("token", token).Msg("no device with token found") + token, err := getToken(request) + if err != nil { + msg := "Issue fetching token" + log.Info().Msg(msg) + http.Error(writer, msg, 400) + return request + } + + device := &model.Device{} + if db.Where("token = ?", token).Find(device).RecordNotFound() { + msg := "No device with token found" + log.Info().Str("token", token).Msg(msg) + http.Error(writer, msg, 400) + return request + } + + user := &model.User{} + if db.Find(user, device.UserID).RecordNotFound() { + log.Panic().Int("userID", device.UserID).Int("deviceID", device.ID).Msg("User not found") + } + log.Info().Int("userid", device.UserID).Str("username", user.Name).Msg("User found") + + impersonate := request.Header.Get("X-Traggo-Impersonate") + if impersonate != "" { + if !user.Admin { + msg := "Trying to impersonate without being admin" + log.Info().Str("impersonate", impersonate).Msg(msg) + http.Error(writer, msg, 403) return request } - - user := &model.User{} - if db.Find(user, device.UserID).RecordNotFound() { - log.Panic().Int("userID", device.UserID).Int("deviceID", device.ID).Msg("User not found") + userID, err := strconv.Atoi(impersonate) + if err != nil { + msg := "Unable to parse impersonation header" + log.Info().Str("impersonate", impersonate).Msg(msg) + http.Error(writer, msg, 400) + return request + } + impersonateUser := &model.User{} + if db.Find(impersonateUser, userID).RecordNotFound() { + msg := "Impersonation user not found" + log.Info().Int("userid", userID).Msg(msg) + http.Error(writer, msg, 400) + return request } + user = impersonateUser + log.Info().Str("username", user.Name).Msg("Impersonation user") + } - if device.ActiveAt.Before(time.Now().Add(5 * -time.Minute)) { - log.Debug().Int("deviceId", device.ID).Str("deviceName", device.Name).Msg("update device activeAt") - device.ActiveAt = timeNow() - db.Save(device) + if device.ActiveAt.Before(time.Now().Add(5 * -time.Minute)) { + log.Debug().Int("deviceId", device.ID).Str("deviceName", device.Name).Msg("update device activeAt") + device.ActiveAt = timeNow() + db.Save(device) - if cookie, err := request.Cookie(traggoSession); err == nil && cookie != nil { - cookie.MaxAge = device.Type.Seconds() - http.SetCookie(writer, cookie) - } + if cookie, err := request.Cookie(traggoSession); err == nil && cookie != nil { + cookie.MaxAge = device.Type.Seconds() + http.SetCookie(writer, cookie) } - - return request.WithContext(WithUser(WithDevice(request.Context(), device), user)) } - return request + return request.WithContext(WithUser(WithDevice(request.Context(), device), user)) } func getToken(request *http.Request) (string, error) { diff --git a/auth/middleware_test.go b/auth/middleware_test.go index ec8c2aa..2c4459e 100644 --- a/auth/middleware_test.go +++ b/auth/middleware_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "io" "net/http" "net/http/httptest" "testing" @@ -196,6 +197,129 @@ func TestMiddleware_destroySession_destroysCookie(t *testing.T) { assert.Equal(t, "traggo=; Max-Age=0", cookieHeader) } +func TestMiddleware_impersonate_no_admin(t *testing.T) { + now := test.Time("2018-06-30T18:30:00Z") + timeDispose := fakeTime(now) + defer timeDispose() + + db := test.InMemoryDB(t) + defer db.Close() + builder := db.User(1) + builder.NewDevice(2, "abc", "test") + spy := &requestSpy{} + recorder := httptest.NewRecorder() + + request := httptest.NewRequest("GET", "/test?token=abc", nil) + request.Header.Set("X-Traggo-Impersonate", "empty") + Middleware(db.DB)(spy).ServeHTTP(recorder, request) + + response := recorder.Result() + assert.Equal(t, 403, response.StatusCode) + + bodyBytes, _ := io.ReadAll(response.Body) + bodyString := string(bodyBytes) + assert.Equal(t, "Trying to impersonate without being admin\n", bodyString) + + ctx := spy.req.Context() + assert.Nil(t, GetUser(ctx)) + assert.Nil(t, GetDevice(ctx)) +} + +func TestMiddleware_impersonate_invalid_personate_header(t *testing.T) { + now := test.Time("2018-06-30T18:30:00Z") + timeDispose := fakeTime(now) + defer timeDispose() + + db := test.InMemoryDB(t) + defer db.Close() + builder := db.User(1) + user := builder.User + user.Admin = true + db.Save(user) + builder.NewDevice(2, "abc", "test") + spy := &requestSpy{} + recorder := httptest.NewRecorder() + + request := httptest.NewRequest("GET", "/test?token=abc", nil) + request.Header.Set("X-Traggo-Impersonate", "invalid") + Middleware(db.DB)(spy).ServeHTTP(recorder, request) + + response := recorder.Result() + assert.Equal(t, 400, response.StatusCode) + + bodyBytes, _ := io.ReadAll(response.Body) + bodyString := string(bodyBytes) + assert.Equal(t, "Unable to parse impersonation header\n", bodyString) + + ctx := spy.req.Context() + assert.Nil(t, GetUser(ctx)) + assert.Nil(t, GetDevice(ctx)) +} + +func TestMiddleware_impersonate_non_existing_user(t *testing.T) { + now := test.Time("2018-06-30T18:30:00Z") + timeDispose := fakeTime(now) + defer timeDispose() + + db := test.InMemoryDB(t) + defer db.Close() + builder := db.User(1) + user := builder.User + user.Admin = true + db.Save(user) + builder.NewDevice(2, "abc", "test") + spy := &requestSpy{} + recorder := httptest.NewRecorder() + + request := httptest.NewRequest("GET", "/test?token=abc", nil) + request.Header.Set("X-Traggo-Impersonate", "42") + Middleware(db.DB)(spy).ServeHTTP(recorder, request) + + response := recorder.Result() + assert.Equal(t, 400, response.StatusCode) + + bodyBytes, _ := io.ReadAll(response.Body) + bodyString := string(bodyBytes) + assert.Equal(t, "Impersonation user not found\n", bodyString) + + ctx := spy.req.Context() + assert.Nil(t, GetUser(ctx)) + assert.Nil(t, GetDevice(ctx)) +} + +func TestMiddleware_impersonate_happy(t *testing.T) { + now := test.Time("2018-06-30T18:30:00Z") + timeDispose := fakeTime(now) + defer timeDispose() + + db := test.InMemoryDB(t) + defer db.Close() + admin_builder := db.User(1) + admin_user := admin_builder.User + admin_user.Admin = true + db.Save(admin_user) + admin_device := admin_builder.NewDevice(2, "abc", "test") + + builder := db.User(2) + user := builder.User + + spy := &requestSpy{} + recorder := httptest.NewRecorder() + + request := httptest.NewRequest("GET", "/test?token=abc", nil) + request.Header.Set("X-Traggo-Impersonate", "2") + Middleware(db.DB)(spy).ServeHTTP(recorder, request) + + response := recorder.Result() + assert.Equal(t, 200, response.StatusCode) + + ctx := spy.req.Context() + assert.Equal(t, &user, GetUser(ctx)) + + admin_device.ActiveAt = now + assert.Equal(t, &admin_device, GetDevice(ctx)) +} + func TestGetCreateSession_panicsWhenMiddlewareWasNotExecuted(t *testing.T) { assert.Panics(t, func() { GetCreateSession(context.Background())