diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 0ce3f36d3..ced713d1a 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,8 +17,11 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" @@ -94,19 +97,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + sassert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } +// withRequestID is a helper that calls into [logging.WithRequestID] and returns +// a new context with the requestID added to the provided context. +func withRequestID(ctx context.Context, requestID string) context.Context { + return logging.WithRequestID(ctx, requestID) +} + func TestWebhookController_Enrich(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -131,6 +139,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -145,6 +154,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -168,6 +178,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -187,14 +198,15 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + sassert.FatalError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -209,6 +221,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -223,6 +236,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -234,19 +248,21 @@ func TestWebhookController_Enrich(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Enrich(context.Background(), test.req) + err := test.ctl.Enrich(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + sassert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -256,12 +272,11 @@ func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -282,6 +297,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -292,6 +308,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -302,13 +319,14 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -322,6 +340,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -334,6 +353,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -344,15 +364,17 @@ func TestWebhookController_Authorize(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Authorize(context.Background(), test.req) + err := test.ctl.Authorize(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -368,6 +390,7 @@ func TestWebhook_Do(t *testing.T) { type test struct { webhook Webhook dataArg any + requestID string webhookResponse webhook.ResponseBody expectPath string errStatusCode int @@ -377,6 +400,16 @@ func TestWebhook_Do(t *testing.T) { } tests := map[string]test{ "ok": { + webhook: Webhook{ + ID: "abc123", + Secret: "c2VjcmV0Cg==", + }, + requestID: "reqID", + webhookResponse: webhook.ResponseBody{ + Data: map[string]interface{}{"role": "dba"}, + }, + }, + "ok/no-request-id": { webhook: Webhook{ ID: "abc123", Secret: "c2VjcmV0Cg==", @@ -391,6 +424,7 @@ func TestWebhook_Do(t *testing.T) { Secret: "c2VjcmV0Cg==", BearerToken: "mytoken", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -407,6 +441,7 @@ func TestWebhook_Do(t *testing.T) { Password: "mypass", }, }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -418,7 +453,8 @@ func TestWebhook_Do(t *testing.T) { URL: "/users/{{ .username }}?region={{ .region }}", Secret: "c2VjcmV0Cg==", }, - dataArg: map[string]interface{}{"username": "areed", "region": "central"}, + requestID: "reqID", + dataArg: map[string]interface{}{"username": "areed", "region": "central"}, webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -453,6 +489,7 @@ func TestWebhook_Do(t *testing.T) { ID: "abc123", Secret: "c2VjcmV0Cg==", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Allow: true, }, @@ -465,6 +502,7 @@ func TestWebhook_Do(t *testing.T) { webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, + requestID: "reqID", errStatusCode: 404, serverErrMsg: "item not found", expectErr: errors.New("Webhook server responded with 404"), @@ -473,38 +511,42 @@ func TestWebhook_Do(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tc.requestID != "" { + assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) + } + id := r.Header.Get("X-Smallstep-Webhook-ID") - assert.Equals(t, tc.webhook.ID, id) + sassert.Equals(t, tc.webhook.ID, id) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) - assert.FatalError(t, err) + assert.NoError(t, err) body, err := io.ReadAll(r.Body) - assert.FatalError(t, err) + assert.NoError(t, err) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) - assert.FatalError(t, err) + assert.NoError(t, err) h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) - assert.True(t, hmac.Equal(sig, mac)) + sassert.True(t, hmac.Equal(sig, mac)) switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - assert.Equals(t, ah, r.Header.Get("Authorization")) + sassert.Equals(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.FatalError(t, err) + assert.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - assert.Equals(t, ah, whReq.Header.Get("Authorization")) + sassert.Equals(t, ah, whReq.Header.Get("Authorization")) default: - assert.Equals(t, "", r.Header.Get("Authorization")) + sassert.Equals(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + sassert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -514,30 +556,34 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) - assert.FatalError(t, err) - // assert.Equals(t, tc.expectToken, reqBody.Token) + require.NoError(t, err) + // sassert.Equals(t, tc.expectToken, reqBody.Token) err = json.NewEncoder(w).Encode(tc.webhookResponse) - assert.FatalError(t, err) + require.NoError(t, err) })) defer ts.Close() tc.webhook.URL = ts.URL + tc.webhook.URL reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx := context.Background() + if tc.requestID != "" { + ctx = withRequestID(context.Background(), tc.requestID) + } + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - assert.Equals(t, tc.expectErr.Error(), err.Error()) + sassert.Equals(t, tc.expectErr.Error(), err.Error()) return } - assert.FatalError(t, err) + assert.NoError(t, err) - assert.Equals(t, got, &tc.webhookResponse) + sassert.Equals(t, got, &tc.webhookResponse) }) } @@ -550,7 +596,7 @@ func TestWebhook_Do(t *testing.T) { URL: ts.URL, } cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") - assert.FatalError(t, err) + require.NoError(t, err) transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, @@ -560,19 +606,19 @@ func TestWebhook_Do(t *testing.T) { Transport: transport, } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() wh.DisableTLSClientAuth = true _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.Error(t, err) + require.Error(t, err) }) }