Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

propagate context through to the sql store #1030

Merged
merged 8 commits into from
Sep 18, 2018
10 changes: 5 additions & 5 deletions client/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request, _ httprouter.Pa
}

secret := c.Secret
if err := h.Manager.CreateClient(&c); err != nil {
if err := h.Manager.CreateClient(r.Context(), &c); err != nil {
h.H.WriteError(w, r, err)
return
}
Expand Down Expand Up @@ -159,7 +159,7 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request, ps httprouter.P
return
}

if err := h.Manager.UpdateClient(&c); err != nil {
if err := h.Manager.UpdateClient(r.Context(), &c); err != nil {
h.H.WriteError(w, r, err)
return
}
Expand Down Expand Up @@ -191,7 +191,7 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request, ps httprouter.P
// 500: genericError
func (h *Handler) List(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
limit, offset := pagination.Parse(r, 100, 0, 500)
c, err := h.Manager.GetClients(limit, offset)
c, err := h.Manager.GetClients(r.Context(), limit, offset)
if err != nil {
h.H.WriteError(w, r, err)
return
Expand Down Expand Up @@ -233,7 +233,7 @@ func (h *Handler) List(w http.ResponseWriter, r *http.Request, ps httprouter.Par
func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
var id = ps.ByName("id")

c, err := h.Manager.GetConcreteClient(id)
c, err := h.Manager.GetConcreteClient(r.Context(), id)
if err != nil {
h.H.WriteError(w, r, err)
return
Expand Down Expand Up @@ -267,7 +267,7 @@ func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Para
func (h *Handler) Delete(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
var id = ps.ByName("id")

if err := h.Manager.DeleteClient(id); err != nil {
if err := h.Manager.DeleteClient(r.Context(), id); err != nil {
h.H.WriteError(w, r, err)
return
}
Expand Down
14 changes: 8 additions & 6 deletions client/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,27 @@
package client

import (
"context"

"github.com/ory/fosite"
)

type Manager interface {
Storage

Authenticate(id string, secret []byte) (*Client, error)
Authenticate(ctx context.Context, id string, secret []byte) (*Client, error)
}

type Storage interface {
fosite.Storage

CreateClient(c *Client) error
CreateClient(ctx context.Context, c *Client) error

UpdateClient(c *Client) error
UpdateClient(ctx context.Context, c *Client) error

DeleteClient(id string) error
DeleteClient(ctx context.Context, id string) error

GetClients(limit, offset int) (map[string]Client, error)
GetClients(ctx context.Context, limit, offset int) (map[string]Client, error)

GetConcreteClient(id string) (*Client, error)
GetConcreteClient(ctx context.Context, id string) (*Client, error)
}
4 changes: 3 additions & 1 deletion client/manager_0_sql_migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"sync"
"testing"

"context"

"github.com/jmoiron/sqlx"
"github.com/ory/fosite"
"github.com/ory/hydra/client"
Expand Down Expand Up @@ -197,7 +199,7 @@ func TestMigrations(t *testing.T) {
for _, key := range []string{"1-data", "2-data", "3-data", "4-data", "5-data"} {
t.Run("client="+key, func(t *testing.T) {
s := &client.SQLManager{DB: db, Hasher: &fosite.BCrypt{WorkFactor: 4}}
c, err := s.GetConcreteClient(key)
c, err := s.GetConcreteClient(context.TODO(), key)
require.NoError(t, err)
assert.EqualValues(t, c.GetID(), key)
})
Expand Down
22 changes: 11 additions & 11 deletions client/manager_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewMemoryManager(hasher fosite.Hasher) *MemoryManager {
}
}

func (m *MemoryManager) GetConcreteClient(id string) (*Client, error) {
func (m *MemoryManager) GetConcreteClient(ctx context.Context, id string) (*Client, error) {
m.RLock()
defer m.RUnlock()

Expand All @@ -61,12 +61,12 @@ func (m *MemoryManager) GetConcreteClient(id string) (*Client, error) {
return nil, errors.WithStack(sqlcon.ErrNoRows)
}

func (m *MemoryManager) GetClient(_ context.Context, id string) (fosite.Client, error) {
return m.GetConcreteClient(id)
func (m *MemoryManager) GetClient(ctx context.Context, id string) (fosite.Client, error) {
return m.GetConcreteClient(ctx, id)
}

func (m *MemoryManager) UpdateClient(c *Client) error {
o, err := m.GetClient(context.Background(), c.GetID())
func (m *MemoryManager) UpdateClient(ctx context.Context, c *Client) error {
o, err := m.GetClient(ctx, c.GetID())
if err != nil {
return err
}
Expand Down Expand Up @@ -95,11 +95,11 @@ func (m *MemoryManager) UpdateClient(c *Client) error {
return nil
}

func (m *MemoryManager) Authenticate(id string, secret []byte) (*Client, error) {
func (m *MemoryManager) Authenticate(ctx context.Context, id string, secret []byte) (*Client, error) {
m.RLock()
defer m.RUnlock()

c, err := m.GetConcreteClient(id)
c, err := m.GetConcreteClient(ctx, id)
if err != nil {
return nil, err
}
Expand All @@ -111,8 +111,8 @@ func (m *MemoryManager) Authenticate(id string, secret []byte) (*Client, error)
return c, nil
}

func (m *MemoryManager) CreateClient(c *Client) error {
if _, err := m.GetConcreteClient(c.GetID()); err == nil {
func (m *MemoryManager) CreateClient(ctx context.Context, c *Client) error {
if _, err := m.GetConcreteClient(ctx, c.GetID()); err == nil {
return sqlcon.ErrUniqueViolation
}

Expand All @@ -129,7 +129,7 @@ func (m *MemoryManager) CreateClient(c *Client) error {
return nil
}

func (m *MemoryManager) DeleteClient(id string) error {
func (m *MemoryManager) DeleteClient(ctx context.Context, id string) error {
m.Lock()
defer m.Unlock()

Expand All @@ -143,7 +143,7 @@ func (m *MemoryManager) DeleteClient(id string) error {
return nil
}

func (m *MemoryManager) GetClients(limit, offset int) (clients map[string]Client, err error) {
func (m *MemoryManager) GetClients(ctx context.Context, limit, offset int) (clients map[string]Client, err error) {
m.RLock()
defer m.RUnlock()
clients = make(map[string]Client)
Expand Down
18 changes: 9 additions & 9 deletions client/manager_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (m *SQLManager) CreateSchemas() (int, error) {
return n, nil
}

func (m *SQLManager) GetConcreteClient(id string) (*Client, error) {
func (m *SQLManager) GetConcreteClient(ctx context.Context, id string) (*Client, error) {
var d sqlData
if err := m.DB.Get(&d, m.DB.Rebind("SELECT * FROM hydra_client WHERE id=?"), id); err != nil {
return nil, sqlcon.HandleError(err)
Expand All @@ -345,11 +345,11 @@ func (m *SQLManager) GetConcreteClient(id string) (*Client, error) {
return d.ToClient()
}

func (m *SQLManager) GetClient(_ context.Context, id string) (fosite.Client, error) {
return m.GetConcreteClient(id)
func (m *SQLManager) GetClient(ctx context.Context, id string) (fosite.Client, error) {
return m.GetConcreteClient(ctx, id)
}

func (m *SQLManager) UpdateClient(c *Client) error {
func (m *SQLManager) UpdateClient(ctx context.Context, c *Client) error {
o, err := m.GetClient(context.Background(), c.GetID())
if err != nil {
return errors.WithStack(err)
Expand Down Expand Up @@ -381,8 +381,8 @@ func (m *SQLManager) UpdateClient(c *Client) error {
return nil
}

func (m *SQLManager) Authenticate(id string, secret []byte) (*Client, error) {
c, err := m.GetConcreteClient(id)
func (m *SQLManager) Authenticate(ctx context.Context, id string, secret []byte) (*Client, error) {
c, err := m.GetConcreteClient(ctx, id)
if err != nil {
return nil, errors.WithStack(err)
}
Expand All @@ -394,7 +394,7 @@ func (m *SQLManager) Authenticate(id string, secret []byte) (*Client, error) {
return c, nil
}

func (m *SQLManager) CreateClient(c *Client) error {
func (m *SQLManager) CreateClient(ctx context.Context, c *Client) error {
h, err := m.Hasher.Hash([]byte(c.Secret))
if err != nil {
return errors.WithStack(err)
Expand All @@ -417,14 +417,14 @@ func (m *SQLManager) CreateClient(c *Client) error {
return nil
}

func (m *SQLManager) DeleteClient(id string) error {
func (m *SQLManager) DeleteClient(ctx context.Context, id string) error {
if _, err := m.DB.Exec(m.DB.Rebind(`DELETE FROM hydra_client WHERE id=?`), id); err != nil {
return sqlcon.HandleError(err)
}
return nil
}

func (m *SQLManager) GetClients(limit, offset int) (clients map[string]Client, err error) {
func (m *SQLManager) GetClients(ctx context.Context, limit, offset int) (clients map[string]Client, err error) {
d := make([]sqlData, 0)
clients = make(map[string]Client)

Expand Down
32 changes: 19 additions & 13 deletions client/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"crypto/x509"
"testing"

"context"

"github.com/ory/fosite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -32,32 +34,34 @@ import (

func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
t.Parallel()
c := &Client{
ClientID: "foo",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
TermsOfServiceURI: "foo",
}
assert.NoError(t, m.CreateClient(c))
assert.NoError(t, m.CreateClient(ctx, c))
//assert.NotEmpty(t, c.ID)
assert.NoError(t, m.DeleteClient(c.GetID()))
assert.NoError(t, m.DeleteClient(ctx, c.GetID()))
}
}

func TestHelperClientAuthenticate(k string, m Manager) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
t.Parallel()
m.CreateClient(&Client{
m.CreateClient(ctx, &Client{
ClientID: "1234321",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
})

c, err := m.Authenticate("1234321", []byte("secret1"))
c, err := m.Authenticate(ctx, "1234321", []byte("secret1"))
require.NotNil(t, err)

c, err = m.Authenticate("1234321", []byte("secret"))
c, err = m.Authenticate(ctx, "1234321", []byte("secret"))
require.NoError(t, err)
assert.Equal(t, "1234321", c.GetID())
}
Expand All @@ -69,6 +73,8 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {
_, err := m.GetClient(nil, "4321")
assert.NotNil(t, err)

ctx := context.TODO()

c := &Client{
ClientID: "1234",
Name: "name",
Expand All @@ -94,13 +100,13 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {
UserinfoSignedResponseAlg: "RS256",
}

assert.NoError(t, m.CreateClient(c))
assert.NoError(t, m.CreateClient(ctx, c))
assert.Equal(t, c.GetID(), "1234")
if k != "http" {
assert.NotEmpty(t, c.GetHashedSecret())
}

assert.NoError(t, m.CreateClient(&Client{
assert.NoError(t, m.CreateClient(ctx, &Client{
ClientID: "2-1234",
Name: "name",
Secret: "secret",
Expand All @@ -114,7 +120,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {

compare(t, c, d, k)

ds, err := m.GetClients(100, 0)
ds, err := m.GetClients(ctx, 100, 0)
assert.NoError(t, err)
assert.Len(t, ds, 2)
assert.NotEqual(t, ds["1234"].ClientID, ds["2-1234"].ClientID)
Expand All @@ -124,15 +130,15 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {
assert.Equal(t, ds["1234"].SecretExpiresAt, 0)
assert.Equal(t, ds["2-1234"].SecretExpiresAt, 1)

ds, err = m.GetClients(1, 0)
ds, err = m.GetClients(ctx, 1, 0)
assert.NoError(t, err)
assert.Len(t, ds, 1)

ds, err = m.GetClients(100, 100)
ds, err = m.GetClients(ctx, 100, 100)
assert.NoError(t, err)
assert.Len(t, ds, 0)

err = m.UpdateClient(&Client{
err = m.UpdateClient(ctx, &Client{
ClientID: "2-1234",
Name: "name-new",
Secret: "secret-new",
Expand All @@ -141,7 +147,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {
})
require.NoError(t, err)

nc, err := m.GetConcreteClient("2-1234")
nc, err := m.GetConcreteClient(ctx, "2-1234")
require.NoError(t, err)

if k != "http" {
Expand All @@ -153,7 +159,7 @@ func TestHelperCreateGetDeleteClient(k string, m Storage) func(t *testing.T) {
assert.EqualValues(t, []string{"http://redirect/new"}, nc.GetRedirectURIs())
assert.Zero(t, len(nc.Contacts))

err = m.DeleteClient("1234")
err = m.DeleteClient(ctx, "1234")
assert.NoError(t, err)

_, err = m.GetClient(nil, "1234")
Expand Down
6 changes: 4 additions & 2 deletions cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"strings"
"time"

"context"

"github.com/ory/hydra/config"
"github.com/ory/hydra/jwk"
"github.com/ory/hydra/pkg"
Expand Down Expand Up @@ -105,10 +107,10 @@ func getOrCreateTLSCertificate(cmd *cobra.Command, c *config.Config) tls.Certifi
}

privateKey.Certificates = []*x509.Certificate{cert}
if err := ctx.KeyManager.DeleteKey(tlsKeyName, privateKey.KeyID); err != nil {
if err := ctx.KeyManager.DeleteKey(context.TODO(), tlsKeyName, privateKey.KeyID); err != nil {
c.GetLogger().WithError(err).Fatalf(`Could not update (delete) the self signed TLS certificate.`)
}
if err := ctx.KeyManager.AddKey(tlsKeyName, privateKey); err != nil {
if err := ctx.KeyManager.AddKey(context.TODO(), tlsKeyName, privateKey); err != nil {
c.GetLogger().WithError(err).Fatalf(`Could not update (add) the self signed TLS certificate.`)
}
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/server/helper_cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
func newCORSMiddleware(
enable bool, c *config.Config,
o func(ctx context.Context, token string, tokenType fosite.TokenType, session fosite.Session, scope ...string) (fosite.TokenType, fosite.AccessRequester, error),
clm func(id string) (*client.Client, error),
clm func(ctx context.Context, id string) (*client.Client, error),
) func(h http.Handler) http.Handler {
if !enable {
return func(h http.Handler) http.Handler {
Expand Down Expand Up @@ -76,7 +76,7 @@ func newCORSMiddleware(
username = ar.GetClient().GetID()
}

cl, err := clm(username)
cl, err := clm(r.Context(), username)
if err != nil {
return false
}
Expand Down
Loading