diff --git a/components/public-api-server/pkg/server/server.go b/components/public-api-server/pkg/server/server.go index 0b96e4e565c898..9345d624c17a11 100644 --- a/components/public-api-server/pkg/server/server.go +++ b/components/public-api-server/pkg/server/server.go @@ -5,8 +5,13 @@ package server import ( + "encoding/json" "fmt" + "github.com/gitpod-io/gitpod/common-go/log" + "net/http" "net/url" + "os" + "strings" "github.com/gitpod-io/gitpod/public-api/config" "github.com/gorilla/handlers" @@ -48,9 +53,18 @@ func Start(logger *logrus.Entry, cfg *config.Configuration) error { } } - srv.HTTPMux().Handle("/stripe/invoices/webhook", - handlers.ContentTypeHandler(webhooks.NewStripeWebhookHandler(billingService), "application/json"), - ) + var stripeWebhookHandler http.Handler = webhooks.NewNoopWebhookHandler() + if cfg.StripeWebhookSigningSecretPath != "" { + stripeWebhookSecret, err := readStripeWebhookSecret(cfg.StripeWebhookSigningSecretPath) + if err != nil { + return fmt.Errorf("failed to read stripe secret: %w", err) + } + stripeWebhookHandler = webhooks.NewStripeWebhookHandler(billingService, stripeWebhookSecret) + } else { + log.Info("No stripe webhook secret is configured, endpoints will return NotImplemented") + } + + srv.HTTPMux().Handle("/stripe/invoices/webhook", handlers.ContentTypeHandler(stripeWebhookHandler, "application/json")) if registerErr := register(srv, gitpodAPI, registry); registerErr != nil { return fmt.Errorf("failed to register services: %w", registerErr) @@ -73,3 +87,18 @@ func register(srv *baseserver.Server, serverAPIURL *url.URL, registry *prometheu return nil } + +func readStripeWebhookSecret(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read stripe webhook secret: %w", err) + } + + var stripeSecret string + err = json.Unmarshal(b, &stripeSecret) + if err != nil { + return "", fmt.Errorf("failed to parse stripe webhook secret: %w", err) + } + + return strings.TrimSpace(stripeSecret), nil +} diff --git a/components/public-api-server/pkg/webhooks/stripe.go b/components/public-api-server/pkg/webhooks/stripe.go index 7668b6b3f39515..112971d7b135a1 100644 --- a/components/public-api-server/pkg/webhooks/stripe.go +++ b/components/public-api-server/pkg/webhooks/stripe.go @@ -5,40 +5,53 @@ package webhooks import ( - "encoding/json" - "net/http" - "github.com/gitpod-io/gitpod/common-go/log" "github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice" - "github.com/stripe/stripe-go/v72" + "github.com/stripe/stripe-go/v72/webhook" + "io" + "net/http" ) +const maxBodyBytes = int64(65536) + type webhookHandler struct { - billingService billingservice.Interface + billingService billingservice.Interface + stripeWebhookSignature string } -func NewStripeWebhookHandler(billingService billingservice.Interface) *webhookHandler { - return &webhookHandler{billingService: billingService} +func NewStripeWebhookHandler(billingService billingservice.Interface, stripeWebhookSignature string) *webhookHandler { + return &webhookHandler{ + billingService: billingService, + stripeWebhookSignature: stripeWebhookSignature, + } } func (h *webhookHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - const maxBodyBytes = int64(65536) - if req.Method != http.MethodPost { log.Errorf("Bad HTTP method: %s", req.Method) - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusMethodNotAllowed) return } - // TODO: verify webhook signature. - // Conditional on there being a secret configured. + stripeSignature := req.Header.Get("Stripe-Signature") + if stripeSignature == "" { + w.WriteHeader(http.StatusBadRequest) + return + } req.Body = http.MaxBytesReader(w, req.Body, maxBodyBytes) - event := stripe.Event{} - err := json.NewDecoder(req.Body).Decode(&event) + payload, err := io.ReadAll(req.Body) + if err != nil { + log.WithError(err).Error("Failed to read payload body.") + w.WriteHeader(http.StatusBadRequest) + return + } + + // https://stripe.com/docs/webhooks/signatures#verify-official-libraries + event, err := webhook.ConstructEvent(payload, req.Header.Get("Stripe-Signature"), h.stripeWebhookSignature) if err != nil { - log.WithError(err).Error("Stripe webhook error while parsing event payload") + log.WithError(err).Error("Failed to verify webhook signature.") w.WriteHeader(http.StatusBadRequest) return } diff --git a/components/public-api-server/pkg/webhooks/stripe_noop.go b/components/public-api-server/pkg/webhooks/stripe_noop.go new file mode 100644 index 00000000000000..2ad43bbfc1cece --- /dev/null +++ b/components/public-api-server/pkg/webhooks/stripe_noop.go @@ -0,0 +1,21 @@ +// Copyright (c) 2022 Gitpod GmbH. All rights reserved. +// Licensed under the GNU Affero General Public License (AGPL). +// See License-AGPL.txt in the project root for license information. + +package webhooks + +import ( + "github.com/gitpod-io/gitpod/common-go/log" + "net/http" +) + +func NewNoopWebhookHandler() *noopWebhookHandler { + return &noopWebhookHandler{} +} + +type noopWebhookHandler struct{} + +func (h *noopWebhookHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + log.Info("Received Stripe webhook handler, but running in no-op mode so will not be handing it.") + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/components/public-api-server/pkg/webhooks/stripe_test.go b/components/public-api-server/pkg/webhooks/stripe_test.go index e1e6dee6370ad7..57c49bf7130d96 100644 --- a/components/public-api-server/pkg/webhooks/stripe_test.go +++ b/components/public-api-server/pkg/webhooks/stripe_test.go @@ -5,11 +5,13 @@ package webhooks import ( + "bytes" + "encoding/hex" "fmt" - "io" + "github.com/stripe/stripe-go/v72/webhook" "net/http" - "strings" "testing" + "time" "github.com/gitpod-io/gitpod/common-go/baseserver" "github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice" @@ -25,6 +27,10 @@ const ( customerCreatedEventType = "customer.created" ) +const ( + testWebhookSecret = "whsec_random_secret" +) + func TestWebhookAcceptsPostRequests(t *testing.T) { scenarios := []struct { HttpMethod string @@ -36,11 +42,11 @@ func TestWebhookAcceptsPostRequests(t *testing.T) { }, { HttpMethod: http.MethodGet, - ExpectedStatusCode: http.StatusBadRequest, + ExpectedStatusCode: http.StatusMethodNotAllowed, }, { HttpMethod: http.MethodPut, - ExpectedStatusCode: http.StatusBadRequest, + ExpectedStatusCode: http.StatusMethodNotAllowed, }, } @@ -52,9 +58,11 @@ func TestWebhookAcceptsPostRequests(t *testing.T) { for _, scenario := range scenarios { t.Run(scenario.HttpMethod, func(t *testing.T) { - req, err := http.NewRequest(scenario.HttpMethod, url, payload) + req, err := http.NewRequest(scenario.HttpMethod, url, bytes.NewReader(payload)) require.NoError(t, err) + req.Header.Set("Stripe-Signature", generateHeader(payload, testWebhookSecret)) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -89,9 +97,12 @@ func TestWebhookIgnoresIrrelevantEvents(t *testing.T) { for _, scenario := range scenarios { t.Run(scenario.EventType, func(t *testing.T) { payload := payloadForStripeEvent(t, scenario.EventType) - req, err := http.NewRequest(http.MethodPost, url, payload) + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload)) require.NoError(t, err) + req.Header.Set("Stripe-Signature", generateHeader(payload, testWebhookSecret)) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -113,9 +124,11 @@ func TestWebhookInvokesFinalizeInvoiceRPC(t *testing.T) { url := fmt.Sprintf("%s%s", srv.HTTPAddress(), "/webhook") payload := payloadForStripeEvent(t, invoiceFinalizedEventType) - req, err := http.NewRequest(http.MethodPost, url, payload) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload)) require.NoError(t, err) + req.Header.Set("Stripe-Signature", generateHeader(payload, testWebhookSecret)) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -129,18 +142,18 @@ func baseServerWithStripeWebhook(t *testing.T, billingService billingservice.Int ) baseserver.StartServerForTests(t, srv) - srv.HTTPMux().Handle("/webhook", NewStripeWebhookHandler(billingService)) + srv.HTTPMux().Handle("/webhook", NewStripeWebhookHandler(billingService, testWebhookSecret)) return srv } -func payloadForStripeEvent(t *testing.T, eventType string) io.Reader { +func payloadForStripeEvent(t *testing.T, eventType string) []byte { t.Helper() if eventType != invoiceFinalizedEventType { - return strings.NewReader(`{}`) + return []byte(`{}`) } - return strings.NewReader(` + return []byte(` { "data": { "object": { @@ -151,3 +164,9 @@ func payloadForStripeEvent(t *testing.T, eventType string) io.Reader { } `) } + +func generateHeader(payload []byte, secret string) string { + now := time.Now() + signature := webhook.ComputeSignature(now, payload, secret) + return fmt.Sprintf("t=%d,%s=%s", now.Unix(), "v1", hex.EncodeToString(signature)) +}