diff --git a/client/manager_test.go b/client/manager_test.go index 38504a3239d..413e48da04a 100644 --- a/client/manager_test.go +++ b/client/manager_test.go @@ -16,9 +16,7 @@ import ( . "github.com/ory/hydra/client" "github.com/ory/hydra/compose" "github.com/ory/hydra/integration" - "github.com/ory/hydra/pkg" "github.com/ory/ladon" - "github.com/stretchr/testify/assert" ) var clientManagers = map[string]Storage{} @@ -103,18 +101,8 @@ func TestAuthenticateClient(t *testing.T) { Clients: map[string]Client{}, Hasher: &fosite.BCrypt{}, } - mem.CreateClient(&Client{ - ID: "1234", - Secret: "secret", - RedirectURIs: []string{"http://redirect"}, - }) - - c, err := mem.Authenticate("1234", []byte("secret1")) - pkg.AssertError(t, true, err) - c, err = mem.Authenticate("1234", []byte("secret")) - pkg.AssertError(t, false, err) - assert.Equal(t, "1234", c.ID) + TestHelperClientAuthenticate("", mem)(t) } func TestCreateGetDeleteClient(t *testing.T) { diff --git a/client/manager_test_helpers.go b/client/manager_test_helpers.go index 5ac5ab4293d..8543539d57a 100644 --- a/client/manager_test_helpers.go +++ b/client/manager_test_helpers.go @@ -2,9 +2,11 @@ package client import ( "testing" - "github.com/stretchr/testify/assert" "time" + "github.com/ory/fosite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) { @@ -14,9 +16,26 @@ func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) { RedirectURIs: []string{"http://redirect"}, TermsOfServiceURI: "foo", } - assert.Nil(t, m.CreateClient(c)) + assert.NoError(t, m.CreateClient(c)) assert.NotEmpty(t, c.ID) - assert.Nil(t, m.DeleteClient(c.ID)) + assert.NoError(t, m.DeleteClient(c.ID)) + } +} + +func TestHelperClientAuthenticate(k string, m Manager) func(t *testing.T) { + return func(t *testing.T) { + m.CreateClient(&Client{ + ID: "1234321", + Secret: "secret", + RedirectURIs: []string{"http://redirect"}, + }) + + c, err := m.Authenticate("1234321", []byte("secret1")) + require.NotNil(t, err) + + c, err = m.Authenticate("1234321", []byte("secret")) + require.Error(t, err) + assert.Equal(t, "1234321", c.ID) } } @@ -32,8 +51,9 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { RedirectURIs: []string{"http://redirect"}, TermsOfServiceURI: "foo", } + err = m.CreateClient(c) - assert.Nil(t, err) + assert.NoError(t, err) if err == nil { compare(t, c, k) } @@ -45,19 +65,20 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { RedirectURIs: []string{"http://redirect"}, TermsOfServiceURI: "foo", }) - assert.Nil(t, err) + assert.NoError(t, err) // RethinkDB delay time.Sleep(100 * time.Millisecond) d, err := m.GetClient(nil, "1234") - assert.Nil(t, err) + assert.NoError(t, err) + if err == nil { compare(t, d, k) } ds, err := m.GetClients() - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, ds, 2) assert.NotEqual(t, ds["1234"].ID, ds["2-1234"].ID) @@ -68,11 +89,11 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { RedirectURIs: []string{"http://redirect/new"}, TermsOfServiceURI: "bar", }) - assert.Nil(t, err) + assert.NoError(t, err) time.Sleep(100 * time.Millisecond) nc, err := m.GetConcreteClient("2-1234") - assert.Nil(t, err) + assert.NoError(t, err) if k != "http" { // http always returns an empty secret @@ -84,7 +105,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) { assert.Zero(t, len(nc.Contacts)) err = m.DeleteClient("1234") - assert.Nil(t, err) + assert.NoError(t, err) // RethinkDB delay time.Sleep(100 * time.Millisecond) diff --git a/jwk/manager_test.go b/jwk/manager_test.go index 0132822a44a..8bfc16e30dc 100644 --- a/jwk/manager_test.go +++ b/jwk/manager_test.go @@ -1,9 +1,7 @@ package jwk_test import ( - "crypto/rand" "fmt" - "io" "log" "net/http" "net/http/httptest" @@ -20,7 +18,6 @@ import ( . "github.com/ory/hydra/jwk" "github.com/ory/hydra/pkg" "github.com/ory/ladon" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -69,15 +66,7 @@ func init() { managers["http"] = httpManager } -func randomBytes(n int) ([]byte, error) { - bytes := make([]byte, n) - if _, err := io.ReadFull(rand.Reader, bytes); err != nil { - return []byte{}, errors.WithStack(err) - } - return bytes, nil -} - -var encryptionKey, _ = randomBytes(32) +var encryptionKey, _ = RandomBytes(32) func TestMain(m *testing.M) { connectToPG() @@ -132,48 +121,16 @@ func TestHTTPManagerPublicKeyGet(t *testing.T) { func TestManagerKey(t *testing.T) { ks, _ := testGenerator.Generate("") - priv := ks.Key("private") - pub := ks.Key("public") for name, m := range managers { t.Run(fmt.Sprintf("case=%s", name), func(t *testing.T) { - _, err := m.GetKey("faz", "baz") - assert.NotNil(t, err) - - err = m.AddKey("faz", First(priv)) - assert.Nil(t, err) - - time.Sleep(time.Millisecond * 100) - - got, err := m.GetKey("faz", "private") - assert.Nil(t, err) - assert.Equal(t, priv, got.Keys, "%s", name) - - err = m.AddKey("faz", First(pub)) - assert.Nil(t, err) - - time.Sleep(time.Millisecond * 100) - - got, err = m.GetKey("faz", "private") - assert.Nil(t, err) - assert.Equal(t, priv, got.Keys, "%s", name) - - got, err = m.GetKey("faz", "public") - assert.Nil(t, err) - assert.Equal(t, pub, got.Keys, "%s", name) - - err = m.DeleteKey("faz", "public") - assert.Nil(t, err) - - time.Sleep(time.Millisecond * 100) - - ks, err = m.GetKey("faz", "public") - assert.NotNil(t, err) + TestHelperManagerKey(m, ks)(t) }) } + priv := ks.Key("private") err := managers["http"].AddKey("nonono", First(priv)) - pkg.AssertError(t, true, err, "%s") + assert.NotNil(t, err) } func TestManagerKeySet(t *testing.T) { @@ -182,29 +139,10 @@ func TestManagerKeySet(t *testing.T) { for name, m := range managers { t.Run(fmt.Sprintf("case=%s", name), func(t *testing.T) { - _, err := m.GetKeySet("foo") - pkg.AssertError(t, true, err, name) - - err = m.AddKeySet("bar", ks) - assert.Nil(t, err) - - time.Sleep(time.Millisecond * 100) - - got, err := m.GetKeySet("bar") - assert.Nil(t, err) - assert.Equal(t, ks.Key("public"), got.Key("public"), name) - assert.Equal(t, ks.Key("private"), got.Key("private"), name) - - err = m.DeleteKeySet("bar") - assert.Nil(t, err) - - time.Sleep(time.Millisecond * 100) - - _, err = m.GetKeySet("bar") - assert.NotNil(t, err) + TestHelperManagerKeySet(m, ks)(t) }) } err := managers["http"].AddKeySet("nonono", ks) - pkg.AssertError(t, true, err, "%s") + assert.NotNil(t, err) } diff --git a/jwk/manager_test_helpers.go b/jwk/manager_test_helpers.go new file mode 100644 index 00000000000..7302d0eab0c --- /dev/null +++ b/jwk/manager_test_helpers.go @@ -0,0 +1,75 @@ +package jwk + +import ( + "crypto/rand" + "io" + "testing" + + "github.com/ory/hydra/pkg" + "github.com/pkg/errors" + "github.com/square/go-jose" + "github.com/stretchr/testify/assert" +) + +func RandomBytes(n int) ([]byte, error) { + bytes := make([]byte, n) + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + return []byte{}, errors.WithStack(err) + } + return bytes, nil +} + +func TestHelperManagerKey(m Manager, keys *jose.JsonWebKeySet) func(t *testing.T) { + pub := keys.Key("public") + priv := keys.Key("private") + + return func(t *testing.T) { + _, err := m.GetKey("faz", "baz") + assert.NotNil(t, err) + + err = m.AddKey("faz", First(priv)) + assert.Nil(t, err) + + got, err := m.GetKey("faz", "private") + assert.Nil(t, err) + assert.Equal(t, priv, got.Keys) + + err = m.AddKey("faz", First(pub)) + assert.Nil(t, err) + + got, err = m.GetKey("faz", "private") + assert.Nil(t, err) + assert.Equal(t, priv, got.Keys) + + got, err = m.GetKey("faz", "public") + assert.Nil(t, err) + assert.Equal(t, pub, got.Keys) + + err = m.DeleteKey("faz", "public") + assert.Nil(t, err) + + _, err = m.GetKey("faz", "public") + assert.NotNil(t, err) + } +} + +func TestHelperManagerKeySet(m Manager, keys *jose.JsonWebKeySet) func(t *testing.T) { + return func(t *testing.T) { + _, err := m.GetKeySet("foo") + pkg.AssertError(t, true, err) + + err = m.AddKeySet("bar", keys) + assert.Nil(t, err) + + got, err := m.GetKeySet("bar") + assert.Nil(t, err) + assert.Equal(t, keys.Key("public"), got.Key("public")) + assert.Equal(t, keys.Key("private"), got.Key("private")) + + err = m.DeleteKeySet("bar") + assert.Nil(t, err) + + _, err = m.GetKeySet("bar") + assert.NotNil(t, err) + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index ab046b1a1eb..a0d9b6ae7b2 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,9 +1,9 @@ package metrics import ( - "time" - "sync" "runtime" + "sync" + "time" ) type Metrics struct { @@ -91,27 +91,27 @@ type Snapshot struct { sync.RWMutex *Metrics *HTTPMetrics - Paths map[string]*PathMetrics `json:"paths"` - ID string `json:"id"` - UpTime int64 `json:"uptime"` - start time.Time `json:"-"` - MemorySnapshot *MemorySnapshot `json:"memory"` + Paths map[string]*PathMetrics `json:"paths"` + ID string `json:"id"` + UpTime int64 `json:"uptime"` + start time.Time `json:"-"` + MemorySnapshot *MemorySnapshot `json:"memory"` } type MemorySnapshot struct { - Alloc uint64 `json:"alloc"` - TotalAlloc uint64 `json:"totalAlloc"` - Sys uint64 `json:"sys"` - Lookups uint64 `json:"lookups"` - Mallocs uint64 `json:"mallocs"` - Frees uint64 `json:"frees"` - HeapAlloc uint64 `json:"heapAlloc"` - HeapSys uint64 `json:"heapSys"` - HeapIdle uint64 `json:"heapIdle"` - HeapInuse uint64 `json:"heapInuse"` + Alloc uint64 `json:"alloc"` + TotalAlloc uint64 `json:"totalAlloc"` + Sys uint64 `json:"sys"` + Lookups uint64 `json:"lookups"` + Mallocs uint64 `json:"mallocs"` + Frees uint64 `json:"frees"` + HeapAlloc uint64 `json:"heapAlloc"` + HeapSys uint64 `json:"heapSys"` + HeapIdle uint64 `json:"heapIdle"` + HeapInuse uint64 `json:"heapInuse"` HeapReleased uint64 `json:"heapReleased"` - HeapObjects uint64 `json:"heapObjects"` - NumGC uint32 `json:"numGC"` + HeapObjects uint64 `json:"heapObjects"` + NumGC uint32 `json:"numGC"` } func newMetrics() *Metrics { @@ -127,7 +127,7 @@ func (sw *Snapshot) GetUpTime() int64 { return sw.UpTime } -func (sw *Snapshot) Update() { +func (sw *Snapshot) Update() { sw.Lock() defer sw.Unlock() @@ -136,19 +136,19 @@ func (sw *Snapshot) Update() { // sw.MemorySnapshot = &(MemorySnapshot(m)) sw.MemorySnapshot = &MemorySnapshot{ - Alloc: m.Alloc, - TotalAlloc: m.TotalAlloc, - Sys: m.Sys, - Lookups: m.Lookups, - Mallocs: m.Mallocs, - Frees: m.Frees, - HeapAlloc: m.HeapAlloc, - HeapSys: m.HeapSys, - HeapIdle: m.HeapIdle, - HeapInuse: m.HeapInuse, + Alloc: m.Alloc, + TotalAlloc: m.TotalAlloc, + Sys: m.Sys, + Lookups: m.Lookups, + Mallocs: m.Mallocs, + Frees: m.Frees, + HeapAlloc: m.HeapAlloc, + HeapSys: m.HeapSys, + HeapIdle: m.HeapIdle, + HeapInuse: m.HeapInuse, HeapReleased: m.HeapReleased, - HeapObjects: m.HeapObjects, - NumGC: m.NumGC, + HeapObjects: m.HeapObjects, + NumGC: m.NumGC, } sw.UpTime = int64(time.Now().Sub(sw.start) / time.Second) diff --git a/metrics/middleware_test.go b/metrics/middleware_test.go index 6e694f30d0f..06055b5dff9 100644 --- a/metrics/middleware_test.go +++ b/metrics/middleware_test.go @@ -15,10 +15,11 @@ import ( "math/rand" "time" + "encoding/json" + "github.com/Sirupsen/logrus" "github.com/ory/herodot" "github.com/ory/hydra/health" - "encoding/json" ) func TestMiddleware(t *testing.T) { diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index a409c6bcf4e..55e52e3b47c 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -3,7 +3,6 @@ package oauth2 import ( "context" "fmt" - "net/url" "os" "testing" "time" @@ -62,62 +61,15 @@ func connectToMySQL() { clientManagers["mysql"] = s } -var defaultRequest = fosite.Request{ - RequestedAt: time.Now().Round(time.Second), - Client: &client.Client{ID: "foobar"}, - Scopes: fosite.Arguments{"fa", "ba"}, - GrantedScopes: fosite.Arguments{"fa", "ba"}, - Form: url.Values{"foo": []string{"bar", "baz"}}, - Session: &fosite.DefaultSession{Subject: "bar"}, -} - func TestCreateGetDeleteAuthorizeCodes(t *testing.T) { - ctx := context.Background() for k, m := range clientManagers { - t.Run(fmt.Sprintf("case=%s", k), func(t *testing.T) { - _, err := m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) - assert.NotNil(t, err) - - err = m.CreateAuthorizeCodeSession(ctx, "4321", &defaultRequest) - require.Nil(t, err) - - res, err := m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) - require.Nil(t, err) - AssertObjectKeysEqual(t, &defaultRequest, res, "Scopes", "GrantedScopes", "Form", "Session") - - err = m.DeleteAuthorizeCodeSession(ctx, "4321") - require.Nil(t, err) - - time.Sleep(100 * time.Millisecond) - - _, err = m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) - assert.NotNil(t, err) - }) + t.Run(fmt.Sprintf("case=%s", k), TestHelperCreateGetDeleteAuthorizeCodes(m)) } } func TestCreateGetDeleteAccessTokenSession(t *testing.T) { - ctx := context.Background() for k, m := range clientManagers { - t.Run(fmt.Sprintf("case=%s", k), func(t *testing.T) { - _, err := m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) - assert.NotNil(t, err) - - err = m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) - require.Nil(t, err) - - res, err := m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) - require.Nil(t, err) - AssertObjectKeysEqual(t, &defaultRequest, res, "Scopes", "GrantedScopes", "Form", "Session") - - err = m.DeleteAccessTokenSession(ctx, "4321") - require.Nil(t, err) - - time.Sleep(100 * time.Millisecond) - - _, err = m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) - assert.NotNil(t, err) - }) + t.Run(fmt.Sprintf("case=%s", k), TestHelperCreateGetDeleteAccessTokenSession(m)) } } diff --git a/oauth2/fosite_store_test_helpers.go b/oauth2/fosite_store_test_helpers.go new file mode 100644 index 00000000000..26dcdcb7010 --- /dev/null +++ b/oauth2/fosite_store_test_helpers.go @@ -0,0 +1,65 @@ +package oauth2 + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/hydra/client" + "github.com/ory/hydra/pkg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultRequest = fosite.Request{ + RequestedAt: time.Now().Round(time.Second), + Client: &client.Client{ID: "foobar"}, + Scopes: fosite.Arguments{"fa", "ba"}, + GrantedScopes: fosite.Arguments{"fa", "ba"}, + Form: url.Values{"foo": []string{"bar", "baz"}}, + Session: &fosite.DefaultSession{Subject: "bar"}, +} + +func TestHelperCreateGetDeleteAuthorizeCodes(m pkg.FositeStorer) func(t *testing.T) { + return func(t *testing.T) { + ctx := context.Background() + _, err := m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) + assert.NotNil(t, err) + + err = m.CreateAuthorizeCodeSession(ctx, "4321", &defaultRequest) + require.NoError(t, err) + + res, err := m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) + require.NoError(t, err) + AssertObjectKeysEqual(t, &defaultRequest, res, "Scopes", "GrantedScopes", "Form", "Session") + + err = m.DeleteAuthorizeCodeSession(ctx, "4321") + require.NoError(t, err) + + _, err = m.GetAuthorizeCodeSession(ctx, "4321", &fosite.DefaultSession{}) + assert.NotNil(t, err) + } +} + +func TestHelperCreateGetDeleteAccessTokenSession(m pkg.FositeStorer) func(t *testing.T) { + return func(t *testing.T) { + ctx := context.Background() + _, err := m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) + assert.NotNil(t, err) + + err = m.CreateAccessTokenSession(ctx, "4321", &defaultRequest) + require.NoError(t, err) + + res, err := m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) + require.NoError(t, err) + AssertObjectKeysEqual(t, &defaultRequest, res, "Scopes", "GrantedScopes", "Form", "Session") + + err = m.DeleteAccessTokenSession(ctx, "4321") + require.NoError(t, err) + + _, err = m.GetAccessTokenSession(ctx, "4321", &fosite.DefaultSession{}) + assert.NotNil(t, err) + } +} diff --git a/warden/group/manager_test.go b/warden/group/manager_test.go index 5140eb17c91..411d5942ae7 100644 --- a/warden/group/manager_test.go +++ b/warden/group/manager_test.go @@ -17,8 +17,6 @@ import ( "github.com/ory/hydra/integration" . "github.com/ory/hydra/warden/group" "github.com/ory/ladon" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var clientManagers = map[string]Manager{} @@ -88,43 +86,6 @@ func connectToPG() { func TestManagers(t *testing.T) { for k, m := range clientManagers { - t.Run(fmt.Sprintf("case=%s", k), func(t *testing.T) { - _, err := m.GetGroup("4321") - assert.NotNil(t, err) - - c := &Group{ - ID: "1", - Members: []string{"bar", "foo"}, - } - assert.Nil(t, m.CreateGroup(c)) - assert.Nil(t, m.CreateGroup(&Group{ - ID: "2", - Members: []string{"foo"}, - })) - - d, err := m.GetGroup("1") - require.Nil(t, err) - assert.EqualValues(t, c.Members, d.Members) - assert.EqualValues(t, c.ID, d.ID) - - ds, err := m.FindGroupNames("foo") - require.Nil(t, err) - assert.Len(t, ds, 2) - - assert.Nil(t, m.AddGroupMembers("1", []string{"baz"})) - - ds, err = m.FindGroupNames("baz") - require.Nil(t, err) - assert.Len(t, ds, 1) - - assert.Nil(t, m.RemoveGroupMembers("1", []string{"baz"})) - ds, err = m.FindGroupNames("baz") - require.Nil(t, err) - assert.Len(t, ds, 0) - - assert.Nil(t, m.DeleteGroup("1")) - _, err = m.GetGroup("1") - require.NotNil(t, err) - }) + t.Run(fmt.Sprintf("case=%s", k), TestHelperManagers(m)) } } diff --git a/warden/group/manager_test_helper.go b/warden/group/manager_test_helper.go new file mode 100644 index 00000000000..99a0cee4369 --- /dev/null +++ b/warden/group/manager_test_helper.go @@ -0,0 +1,49 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHelperManagers(m Manager) func(t *testing.T) { + return func(t *testing.T) { + _, err := m.GetGroup("4321") + assert.NotNil(t, err) + + c := &Group{ + ID: "1", + Members: []string{"bar", "foo"}, + } + assert.NoError(t, m.CreateGroup(c)) + assert.NoError(t, m.CreateGroup(&Group{ + ID: "2", + Members: []string{"foo"}, + })) + + d, err := m.GetGroup("1") + require.NoError(t, err) + assert.EqualValues(t, c.Members, d.Members) + assert.EqualValues(t, c.ID, d.ID) + + ds, err := m.FindGroupNames("foo") + require.NoError(t, err) + assert.Len(t, ds, 2) + + assert.NoError(t, m.AddGroupMembers("1", []string{"baz"})) + + ds, err = m.FindGroupNames("baz") + require.NoError(t, err) + assert.Len(t, ds, 1) + + assert.NoError(t, m.RemoveGroupMembers("1", []string{"baz"})) + ds, err = m.FindGroupNames("baz") + require.NoError(t, err) + assert.Len(t, ds, 0) + + assert.NoError(t, m.DeleteGroup("1")) + _, err = m.GetGroup("1") + require.NotNil(t, err) + } +}