diff --git a/kp.go b/kp.go index 3acd82b..a8be552 100644 --- a/kp.go +++ b/kp.go @@ -54,6 +54,8 @@ const ( authContextKey ContextKey = 0 defaultTimeout = 30 // in seconds. + + correlationIdContextKey = "X-Correlation-Id" ) var ( @@ -234,13 +236,14 @@ func (c *Client) do(ctx context.Context, req *http.Request, res interface{}) (*h return nil, err } - // generate our own UUID for the correlation ID and feed it into the request + // retrieve the correlation id from the context. If not present, then a UUID will be + // generated for the correlation ID and feed it into the request // KeyProtect will use this when it is set on a request header rather than generating its // own inside the service - // We generate our own here because a connection error might actually mean the request - // doesn't make it server side, so having a correlation ID locally helps us know that - // when comparing with server side logs. - corrId := uuid.New().String() + // if not present, we generate our own here because a connection error might actually + // mean the request doesn't make it server side, so having a correlation ID locally helps + // us know that when comparing with server side logs. + corrId := c.getCorrelationId(ctx) req.Header.Set("bluemix-instance", c.Config.InstanceID) req.Header.Set("authorization", acccesToken) @@ -365,6 +368,21 @@ func (c *Client) getAccessToken(ctx context.Context) (string, error) { return fmt.Sprintf("%s %s", token.TokenType, token.AccessToken), nil } +// getCorrelationId returns the correlation ID value from the given Context, or +// returns a new UUID if not present +func (c *Client) getCorrelationId(ctx context.Context) string { + if ctx.Value(correlationIdContextKey) != nil { + corrId := ctx.Value(correlationIdContextKey).(string) + _, err := uuid.Parse(corrId) + if err == nil { + return corrId + } + } + + corrId := uuid.New().String() + return corrId +} + // Logger writes when called. type Logger interface { Info(...interface{}) diff --git a/kp_test.go b/kp_test.go index 8292382..bdf1341 100644 --- a/kp_test.go +++ b/kp_test.go @@ -29,6 +29,7 @@ import ( "time" "github.com/IBM/keyprotect-go-client/iam" + "github.com/google/uuid" "github.com/stretchr/testify/assert" gock "gopkg.in/h2non/gock.v1" @@ -1038,6 +1039,43 @@ func TestDo_ConnectionError_HasCorrelationID(t *testing.T) { assert.NotEmpty(t, urlErr.CorrelationID) } +func TestDo_CorrelationID_Set(t *testing.T) { + defer gock.Off() + + gock.New("http://example.com"). + ReplyError(errors.New("test error")) + + c, _, err := NewTestClient(t, nil) + gock.InterceptClient(&c.HttpClient) + defer gock.RestoreClient(&c.HttpClient) + c.tokenSource = &FakeTokenSource{} + + corrId := uuid.New().String() + ctx := context.WithValue(context.Background(), correlationIdContextKey, corrId) + _, err = c.GetKeys(ctx, 0, 0) + assert.Contains(t, err.Error(), "correlation_id='"+corrId+"'") +} + +func TestDo_CorrelationID_NotUUID(t *testing.T) { + defer gock.Off() + + gock.New("http://example.com"). + ReplyError(errors.New("test error")) + + c, _, err := NewTestClient(t, nil) + gock.InterceptClient(&c.HttpClient) + defer gock.RestoreClient(&c.HttpClient) + c.tokenSource = &FakeTokenSource{} + + corrId := "invalid-uuid" + ctx := context.WithValue(context.Background(), correlationIdContextKey, corrId) + _, err = c.GetKeys(ctx, 0, 0) + assert.NotContains(t, err.Error(), corrId) + reasonsErr := err.(*URLError) + _, err = uuid.Parse(reasonsErr.CorrelationID) + assert.NoError(t, err) +} + func TestDo_KPErrorResponseWithReasons_IsErrorStruct(t *testing.T) { defer gock.Off()