From b0d176728aaaea8d9ef0aeb1c7a2a312a9f34342 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 25 May 2022 15:50:12 +0200 Subject: [PATCH 1/2] Change the http verb of /stripe/billing from POST to GET. --- api/routes.go | 2 +- api/stripe.go | 4 ++-- test/api/stripe_test.go | 10 +++++----- test/database/user_test.go | 2 +- test/tester.go | 6 +++--- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/api/routes.go b/api/routes.go index 0692dfb2..41d3aca6 100644 --- a/api/routes.go +++ b/api/routes.go @@ -77,7 +77,7 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.POST("/user/recover/request", api.WithDBSession(api.noAuth(api.userRecoverRequestPOST))) api.staticRouter.POST("/user/recover", api.WithDBSession(api.noAuth(api.userRecoverPOST))) - api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingPOST, false))) + api.staticRouter.GET("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingGET, false))) api.staticRouter.POST("/stripe/checkout", api.WithDBSession(api.withAuth(api.stripeCheckoutPOST, false))) api.staticRouter.GET("/stripe/prices", api.noAuth(api.stripePricesGET)) api.staticRouter.POST("/stripe/webhook", api.WithDBSession(api.noAuth(api.stripeWebhookPOST))) diff --git a/api/stripe.go b/api/stripe.go index 965c8482..a7311381 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -152,10 +152,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er return err } -// stripeBillingPOST creates a new billing session for the user and redirects +// stripeBillingGET creates a new billing session for the user and redirects // them to it. If the user does not yet have a Stripe customer, one is // registered for them. -func (api *API) stripeBillingPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { +func (api *API) stripeBillingGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { if u.StripeID == "" { id, err := api.stripeCreateCustomer(req.Context(), u) if err != nil { diff --git a/test/api/stripe_test.go b/test/api/stripe_test.go index 25ede319..738cbb6a 100644 --- a/test/api/stripe_test.go +++ b/test/api/stripe_test.go @@ -35,7 +35,7 @@ func TestStripe(t *testing.T) { api.StripeTestMode = true tests := map[string]func(t *testing.T, at *test.AccountsTester){ - "post billing": testStripeBillingPOST, + "post billing": testStripeBillingGET, "get prices": testStripePricesGET, "post checkout": testStripeCheckoutPOST, } @@ -52,8 +52,8 @@ func TestStripe(t *testing.T) { } } -// testStripeBillingPOST ensures that we can create a new billing session. -func testStripeBillingPOST(t *testing.T, at *test.AccountsTester) { +// testStripeBillingGET ensures that we can create a new billing session. +func testStripeBillingGET(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) r, _, err := at.UserPOST(name+"@siasky.net", name+"pass") if err != nil { @@ -65,7 +65,7 @@ func testStripeBillingPOST(t *testing.T, at *test.AccountsTester) { // Try to start a billing session without valid user auth. at.ClearCredentials() - _, s, err := at.StripeBillingPOST() + _, s, err := at.StripeBillingGET() if err == nil || s != http.StatusUnauthorized { t.Fatalf("Expected 401 Unauthorized, got %d %s", s, err) } @@ -73,7 +73,7 @@ func testStripeBillingPOST(t *testing.T, at *test.AccountsTester) { // fail case, we expect that to happen. In production we'll follow that // redirect. at.SetCookie(c) - h, s, err := at.StripeBillingPOST() + h, s, err := at.StripeBillingGET() if err != nil || s != http.StatusTemporaryRedirect { t.Fatalf("Expected %d and no error, got %d '%s'", http.StatusTemporaryRedirect, s, err) } diff --git a/test/database/user_test.go b/test/database/user_test.go index 1926f1f6..7918426f 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -265,7 +265,7 @@ func TestUserConfirmEmail(t *testing.T) { t.Fatal("Failed to generate a token.") } // Set the expiration of the token in the past. - u.EmailConfirmationTokenExpiration = time.Now().UTC().Add(-time.Minute) + u.EmailConfirmationTokenExpiration = time.Now().UTC().Add(-time.Minute).Truncate(time.Millisecond) err = db.UserSave(ctx, u) if err != nil { t.Fatal("Failed to save the user:", err) diff --git a/test/tester.go b/test/tester.go index c6bae60a..1c3ffdf2 100644 --- a/test/tester.go +++ b/test/tester.go @@ -668,9 +668,9 @@ func (at *AccountsTester) UploadInfo(sl string) ([]api.UploadInfo, int, error) { /*** Stripe helpers ***/ -// StripeBillingPOST performs a `POST /stripe/billing` -func (at *AccountsTester) StripeBillingPOST() (http.Header, int, error) { - r, err := at.Request(http.MethodPost, "/stripe/billing", nil, nil, nil, nil) +// StripeBillingGET performs a `GET /stripe/billing` +func (at *AccountsTester) StripeBillingGET() (http.Header, int, error) { + r, err := at.Request(http.MethodGet, "/stripe/billing", nil, nil, nil, nil) // We ignore the temporary redirect error because it's the expected // behaviour of this endpoint. if err != nil && !strings.Contains(err.Error(), "307 Temporary Redirect") { From 9c5297f0fc1eb17f2abf4885661697bcfa26d007 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 25 May 2022 16:55:30 +0200 Subject: [PATCH 2/2] Put back the `POST /stripe/billing` route and point it to the same handler. --- api/routes.go | 4 +++- api/stripe.go | 4 ++-- test/api/stripe_test.go | 34 +++++++++++++++++++++++++++++++++- test/tester.go | 11 +++++++++++ 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/api/routes.go b/api/routes.go index 41d3aca6..dfcda593 100644 --- a/api/routes.go +++ b/api/routes.go @@ -77,7 +77,9 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.POST("/user/recover/request", api.WithDBSession(api.noAuth(api.userRecoverRequestPOST))) api.staticRouter.POST("/user/recover", api.WithDBSession(api.noAuth(api.userRecoverPOST))) - api.staticRouter.GET("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingGET, false))) + api.staticRouter.GET("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false))) + // `POST /stripe/billing` is deprecated. Please use `GET /stripe/billing`. + api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false))) api.staticRouter.POST("/stripe/checkout", api.WithDBSession(api.withAuth(api.stripeCheckoutPOST, false))) api.staticRouter.GET("/stripe/prices", api.noAuth(api.stripePricesGET)) api.staticRouter.POST("/stripe/webhook", api.WithDBSession(api.noAuth(api.stripeWebhookPOST))) diff --git a/api/stripe.go b/api/stripe.go index a7311381..fc5cce11 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -152,10 +152,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er return err } -// stripeBillingGET creates a new billing session for the user and redirects +// stripeBillingHANDLER creates a new billing session for the user and redirects // them to it. If the user does not yet have a Stripe customer, one is // registered for them. -func (api *API) stripeBillingGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { +func (api *API) stripeBillingHANDLER(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { if u.StripeID == "" { id, err := api.stripeCreateCustomer(req.Context(), u) if err != nil { diff --git a/test/api/stripe_test.go b/test/api/stripe_test.go index 738cbb6a..9ec52465 100644 --- a/test/api/stripe_test.go +++ b/test/api/stripe_test.go @@ -35,7 +35,8 @@ func TestStripe(t *testing.T) { api.StripeTestMode = true tests := map[string]func(t *testing.T, at *test.AccountsTester){ - "post billing": testStripeBillingGET, + "get billing": testStripeBillingGET, + "post billing": testStripeBillingPOST, "get prices": testStripePricesGET, "post checkout": testStripeCheckoutPOST, } @@ -83,6 +84,37 @@ func testStripeBillingGET(t *testing.T, at *test.AccountsTester) { } } +// testStripeBillingPOST ensures that we can create a new billing session. +func testStripeBillingPOST(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + r, _, err := at.UserPOST(name+"@siasky.net", name+"pass") + if err != nil { + t.Fatal(err) + } + c := test.ExtractCookie(r) + + at.SetFollowRedirects(false) + + // Try to start a billing session without valid user auth. + at.ClearCredentials() + _, s, err := at.StripeBillingPOST() + if err == nil || s != http.StatusUnauthorized { + t.Fatalf("Expected 401 Unauthorized, got %d %s", s, err) + } + // Try with a valid user. Expect a temporary redirect error. This is not a + // fail case, we expect that to happen. In production we'll follow that + // redirect. + at.SetCookie(c) + h, s, err := at.StripeBillingPOST() + if err != nil || s != http.StatusTemporaryRedirect { + t.Fatalf("Expected %d and no error, got %d '%s'", http.StatusTemporaryRedirect, s, err) + } + expectedRedirectPrefix := "https://billing.stripe.com/session/" + if !strings.HasPrefix(h.Get("Location"), expectedRedirectPrefix) { + t.Fatalf("Expected a redirect link with prefix '%s', got '%s'", expectedRedirectPrefix, h.Get("Location")) + } +} + // testStripeCheckoutPOST ensures that we can create a new checkout session. func testStripeCheckoutPOST(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) diff --git a/test/tester.go b/test/tester.go index 1c3ffdf2..999cc12d 100644 --- a/test/tester.go +++ b/test/tester.go @@ -679,6 +679,17 @@ func (at *AccountsTester) StripeBillingGET() (http.Header, int, error) { return r.Header, r.StatusCode, nil } +// StripeBillingPOST performs a `POST /stripe/billing` +func (at *AccountsTester) StripeBillingPOST() (http.Header, int, error) { + r, err := at.Request(http.MethodPost, "/stripe/billing", nil, nil, nil, nil) + // We ignore the temporary redirect error because it's the expected + // behaviour of this endpoint. + if err != nil && !strings.Contains(err.Error(), "307 Temporary Redirect") { + return nil, r.StatusCode, err + } + return r.Header, r.StatusCode, nil +} + // StripeCheckoutPOST performs a `POST /stripe/checkout` func (at *AccountsTester) StripeCheckoutPOST(price string) (string, int, error) { body := struct {