Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stripe] Set reportId on invoices after updating credits #12409

Merged
merged 1 commit into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions components/usage/pkg/apiv1/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *BillingService) UpdateInvoices(ctx context.Context, in *v1.UpdateInvoic
return nil, status.Errorf(codes.Internal, "Failed to download usage report with ID: %s", in.GetReportId())
}

credits, err := s.creditSummaryForTeams(report)
credits, err := s.creditSummaryForTeams(report, in.GetReportId())
if err != nil {
log.Log.WithError(err).Errorf("Failed to compute credit summary.")
return nil, status.Errorf(codes.InvalidArgument, "failed to compute credit summary")
Expand Down Expand Up @@ -100,7 +100,7 @@ func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcom
}, nil
}

func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[string]int64, error) {
func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport, reportID string) (map[string]stripe.CreditSummary, error) {
creditsPerTeamID := map[string]float64{}

for _, session := range sessions {
Expand All @@ -120,9 +120,12 @@ func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[str
creditsPerTeamID[id] += session.CreditsUsed
}

rounded := map[string]int64{}
rounded := map[string]stripe.CreditSummary{}
for teamID, credits := range creditsPerTeamID {
rounded[teamID] = int64(math.Ceil(credits))
rounded[teamID] = stripe.CreditSummary{
Credits: int64(math.Ceil(credits)),
ReportID: reportID,
}
}

return rounded, nil
Expand Down
35 changes: 24 additions & 11 deletions components/usage/pkg/apiv1/billing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ import (
func TestCreditSummaryForTeams(t *testing.T) {
teamID_A, teamID_B := uuid.New().String(), uuid.New().String()
teamAttributionID_A, teamAttributionID_B := db.NewTeamAttributionID(teamID_A), db.NewTeamAttributionID(teamID_B)
reportID := "report_id_1"

scenarios := []struct {
Name string
Sessions db.UsageReport
BillSessionsAfter time.Time
Expected map[string]int64
Expected map[string]stripe.CreditSummary
}{
{
Name: "no instances in report, no summary",
BillSessionsAfter: time.Time{},
Sessions: nil,
Expected: map[string]int64{},
Expected: map[string]stripe.CreditSummary{},
},
{
Name: "skips user attributions",
Expand All @@ -39,7 +40,7 @@ func TestCreditSummaryForTeams(t *testing.T) {
AttributionID: db.NewUserAttributionID(uuid.New().String()),
},
},
Expected: map[string]int64{},
Expected: map[string]stripe.CreditSummary{},
},
{
Name: "two workspace instances",
Expand All @@ -56,9 +57,12 @@ func TestCreditSummaryForTeams(t *testing.T) {
CreditsUsed: 10,
},
},
Expected: map[string]int64{
Expected: map[string]stripe.CreditSummary{
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
teamID_A: 480,
teamID_A: {
Credits: 480,
ReportID: reportID,
},
},
},
{
Expand All @@ -76,10 +80,16 @@ func TestCreditSummaryForTeams(t *testing.T) {
CreditsUsed: (24) * 10,
},
},
Expected: map[string]int64{
Expected: map[string]stripe.CreditSummary{
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
teamID_A: 120,
teamID_B: 240,
teamID_A: {
Credits: 120,
ReportID: reportID,
},
teamID_B: {
Credits: 240,
ReportID: reportID,
},
},
},
{
Expand All @@ -99,16 +109,19 @@ func TestCreditSummaryForTeams(t *testing.T) {
StartedAt: time.Now().AddDate(0, 0, -3),
},
},
Expected: map[string]int64{
teamID_A: 120,
Expected: map[string]stripe.CreditSummary{
teamID_A: {
Credits: 120,
ReportID: reportID,
},
},
},
}

for _, s := range scenarios {
t.Run(s.Name, func(t *testing.T) {
svc := NewBillingService(&stripe.Client{}, s.BillSessionsAfter, &gorm.DB{})
actual, err := svc.creditSummaryForTeams(s.Sessions)
actual, err := svc.creditSummaryForTeams(s.Sessions, reportID)
require.NoError(t, err)
require.Equal(t, s.Expected, actual)
})
Expand Down
43 changes: 39 additions & 4 deletions components/usage/pkg/stripe/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
"github.com/stripe/stripe-go/v72/client"
)

const (
reportIDMetadataKey = "reportId"
)

type Client struct {
sc *client.API
}
Expand Down Expand Up @@ -58,9 +62,14 @@ type Invoice struct {
Credits int64
}

type CreditSummary struct {
Credits int64
ReportID string
}

// UpdateUsage updates teams' Stripe subscriptions with usage data
// `usageForTeam` is a map from team name to total workspace seconds used within a billing period.
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]int64) error {
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]CreditSummary) error {
teamIds := make([]string, 0, len(creditsPerTeam))
for k := range creditsPerTeam {
teamIds = append(teamIds, k)
Expand Down Expand Up @@ -117,7 +126,7 @@ func (c *Client) findCustomers(ctx context.Context, query string) ([]*stripe.Cus
return customers, nil
}

func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, credits int64) (*UsageRecord, error) {
func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, summary CreditSummary) (*UsageRecord, error) {
subscriptions := customer.Subscriptions.Data
if len(subscriptions) != 1 {
return nil, fmt.Errorf("customer has an unexpected number of subscriptions %v (expected 1, got %d)", subscriptions, len(subscriptions))
Expand All @@ -136,15 +145,27 @@ func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Cu
Context: ctx,
},
SubscriptionItem: stripe.String(subscriptionItemId),
Quantity: stripe.Int64(credits),
Quantity: stripe.Int64(summary.Credits),
})
if err != nil {
return nil, fmt.Errorf("failed to register usage for customer %q on subscription item %s", customer.Name, subscriptionItemId)
}

invoice, err := c.GetUpcomingInvoice(ctx, customer.ID)
if err != nil {
return nil, fmt.Errorf("failed to find upcoming invoice for customer %s: %w", customer.ID, err)
}

_, err = c.UpdateInvoiceMetadata(ctx, invoice.ID, map[string]string{
reportIDMetadataKey: summary.ReportID,
})
if err != nil {
return nil, fmt.Errorf("failed to udpate invoice %s metadata with report ID: %w", invoice.ID, err)
}

return &UsageRecord{
SubscriptionItemID: subscriptionItemId,
Quantity: credits,
Quantity: summary.Credits,
}, nil
}

Expand Down Expand Up @@ -205,6 +226,20 @@ func (c *Client) GetUpcomingInvoice(ctx context.Context, customerID string) (*In
}, nil
}

func (c *Client) UpdateInvoiceMetadata(ctx context.Context, invoiceID string, metadata map[string]string) (*stripe.Invoice, error) {
invoice, err := c.sc.Invoices.Update(invoiceID, &stripe.InvoiceParams{
Params: stripe.Params{
Context: ctx,
Metadata: metadata,
},
})
if err != nil {
return nil, fmt.Errorf("failed to update invoice %s metadata: %w", invoiceID, err)
}

return invoice, nil
}

// queriesForCustomersWithTeamIds constructs Stripe query strings to find the Stripe Customer for each teamId
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
// `clausesPerQuery` is a limit enforced by the Stripe API.
Expand Down