From 67064597cd39e8f22cd6cdd9ceba5236289d5501 Mon Sep 17 00:00:00 2001 From: Kush Sharma Date: Sun, 15 Dec 2024 13:57:48 +0530 Subject: [PATCH 1/2] test: unit tests for subscription service Signed-off-by: Kush Sharma --- billing/subscription/mocks/credit_service.go | 141 +++ .../subscription/mocks/customer_service.go | 153 ++++ .../mocks/organization_service.go | 93 ++ billing/subscription/mocks/plan_service.go | 153 ++++ billing/subscription/mocks/product_service.go | 94 ++ billing/subscription/mocks/repository.go | 371 ++++++++ billing/subscription/service_test.go | 815 ++++++++++++++++++ 7 files changed, 1820 insertions(+) create mode 100644 billing/subscription/mocks/credit_service.go create mode 100644 billing/subscription/mocks/customer_service.go create mode 100644 billing/subscription/mocks/organization_service.go create mode 100644 billing/subscription/mocks/plan_service.go create mode 100644 billing/subscription/mocks/product_service.go create mode 100644 billing/subscription/mocks/repository.go create mode 100644 billing/subscription/service_test.go diff --git a/billing/subscription/mocks/credit_service.go b/billing/subscription/mocks/credit_service.go new file mode 100644 index 000000000..9a6510520 --- /dev/null +++ b/billing/subscription/mocks/credit_service.go @@ -0,0 +1,141 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + credit "github.com/raystack/frontier/billing/credit" + mock "github.com/stretchr/testify/mock" +) + +// CreditService is an autogenerated mock type for the CreditService type +type CreditService struct { + mock.Mock +} + +type CreditService_Expecter struct { + mock *mock.Mock +} + +func (_m *CreditService) EXPECT() *CreditService_Expecter { + return &CreditService_Expecter{mock: &_m.Mock} +} + +// Add provides a mock function with given fields: ctx, cred +func (_m *CreditService) Add(ctx context.Context, cred credit.Credit) error { + ret := _m.Called(ctx, cred) + + if len(ret) == 0 { + panic("no return value specified for Add") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, credit.Credit) error); ok { + r0 = rf(ctx, cred) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CreditService_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add' +type CreditService_Add_Call struct { + *mock.Call +} + +// Add is a helper method to define mock.On call +// - ctx context.Context +// - cred credit.Credit +func (_e *CreditService_Expecter) Add(ctx interface{}, cred interface{}) *CreditService_Add_Call { + return &CreditService_Add_Call{Call: _e.mock.On("Add", ctx, cred)} +} + +func (_c *CreditService_Add_Call) Run(run func(ctx context.Context, cred credit.Credit)) *CreditService_Add_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(credit.Credit)) + }) + return _c +} + +func (_c *CreditService_Add_Call) Return(_a0 error) *CreditService_Add_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *CreditService_Add_Call) RunAndReturn(run func(context.Context, credit.Credit) error) *CreditService_Add_Call { + _c.Call.Return(run) + return _c +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *CreditService) GetByID(ctx context.Context, id string) (credit.Transaction, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 credit.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (credit.Transaction, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) credit.Transaction); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(credit.Transaction) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreditService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type CreditService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *CreditService_Expecter) GetByID(ctx interface{}, id interface{}) *CreditService_GetByID_Call { + return &CreditService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *CreditService_GetByID_Call) Run(run func(ctx context.Context, id string)) *CreditService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *CreditService_GetByID_Call) Return(_a0 credit.Transaction, _a1 error) *CreditService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CreditService_GetByID_Call) RunAndReturn(run func(context.Context, string) (credit.Transaction, error)) *CreditService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// NewCreditService creates a new instance of CreditService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCreditService(t interface { + mock.TestingT + Cleanup(func()) +}) *CreditService { + mock := &CreditService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/mocks/customer_service.go b/billing/subscription/mocks/customer_service.go new file mode 100644 index 000000000..5a72cce71 --- /dev/null +++ b/billing/subscription/mocks/customer_service.go @@ -0,0 +1,153 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + customer "github.com/raystack/frontier/billing/customer" + mock "github.com/stretchr/testify/mock" +) + +// CustomerService is an autogenerated mock type for the CustomerService type +type CustomerService struct { + mock.Mock +} + +type CustomerService_Expecter struct { + mock *mock.Mock +} + +func (_m *CustomerService) EXPECT() *CustomerService_Expecter { + return &CustomerService_Expecter{mock: &_m.Mock} +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *CustomerService) GetByID(ctx context.Context, id string) (customer.Customer, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 customer.Customer + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (customer.Customer, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) customer.Customer); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(customer.Customer) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CustomerService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type CustomerService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *CustomerService_Expecter) GetByID(ctx interface{}, id interface{}) *CustomerService_GetByID_Call { + return &CustomerService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *CustomerService_GetByID_Call) Run(run func(ctx context.Context, id string)) *CustomerService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *CustomerService_GetByID_Call) Return(_a0 customer.Customer, _a1 error) *CustomerService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CustomerService_GetByID_Call) RunAndReturn(run func(context.Context, string) (customer.Customer, error)) *CustomerService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, filter +func (_m *CustomerService) List(ctx context.Context, filter customer.Filter) ([]customer.Customer, error) { + ret := _m.Called(ctx, filter) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []customer.Customer + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, customer.Filter) ([]customer.Customer, error)); ok { + return rf(ctx, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, customer.Filter) []customer.Customer); ok { + r0 = rf(ctx, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]customer.Customer) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, customer.Filter) error); ok { + r1 = rf(ctx, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CustomerService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type CustomerService_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - filter customer.Filter +func (_e *CustomerService_Expecter) List(ctx interface{}, filter interface{}) *CustomerService_List_Call { + return &CustomerService_List_Call{Call: _e.mock.On("List", ctx, filter)} +} + +func (_c *CustomerService_List_Call) Run(run func(ctx context.Context, filter customer.Filter)) *CustomerService_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(customer.Filter)) + }) + return _c +} + +func (_c *CustomerService_List_Call) Return(_a0 []customer.Customer, _a1 error) *CustomerService_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CustomerService_List_Call) RunAndReturn(run func(context.Context, customer.Filter) ([]customer.Customer, error)) *CustomerService_List_Call { + _c.Call.Return(run) + return _c +} + +// NewCustomerService creates a new instance of CustomerService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCustomerService(t interface { + mock.TestingT + Cleanup(func()) +}) *CustomerService { + mock := &CustomerService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/mocks/organization_service.go b/billing/subscription/mocks/organization_service.go new file mode 100644 index 000000000..7c8b0d5f3 --- /dev/null +++ b/billing/subscription/mocks/organization_service.go @@ -0,0 +1,93 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// OrganizationService is an autogenerated mock type for the OrganizationService type +type OrganizationService struct { + mock.Mock +} + +type OrganizationService_Expecter struct { + mock *mock.Mock +} + +func (_m *OrganizationService) EXPECT() *OrganizationService_Expecter { + return &OrganizationService_Expecter{mock: &_m.Mock} +} + +// MemberCount provides a mock function with given fields: ctx, orgID +func (_m *OrganizationService) MemberCount(ctx context.Context, orgID string) (int64, error) { + ret := _m.Called(ctx, orgID) + + if len(ret) == 0 { + panic("no return value specified for MemberCount") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, orgID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, orgID) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, orgID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// OrganizationService_MemberCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MemberCount' +type OrganizationService_MemberCount_Call struct { + *mock.Call +} + +// MemberCount is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +func (_e *OrganizationService_Expecter) MemberCount(ctx interface{}, orgID interface{}) *OrganizationService_MemberCount_Call { + return &OrganizationService_MemberCount_Call{Call: _e.mock.On("MemberCount", ctx, orgID)} +} + +func (_c *OrganizationService_MemberCount_Call) Run(run func(ctx context.Context, orgID string)) *OrganizationService_MemberCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *OrganizationService_MemberCount_Call) Return(_a0 int64, _a1 error) *OrganizationService_MemberCount_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *OrganizationService_MemberCount_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *OrganizationService_MemberCount_Call { + _c.Call.Return(run) + return _c +} + +// NewOrganizationService creates a new instance of OrganizationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewOrganizationService(t interface { + mock.TestingT + Cleanup(func()) +}) *OrganizationService { + mock := &OrganizationService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/mocks/plan_service.go b/billing/subscription/mocks/plan_service.go new file mode 100644 index 000000000..6c818c0de --- /dev/null +++ b/billing/subscription/mocks/plan_service.go @@ -0,0 +1,153 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + plan "github.com/raystack/frontier/billing/plan" + mock "github.com/stretchr/testify/mock" +) + +// PlanService is an autogenerated mock type for the PlanService type +type PlanService struct { + mock.Mock +} + +type PlanService_Expecter struct { + mock *mock.Mock +} + +func (_m *PlanService) EXPECT() *PlanService_Expecter { + return &PlanService_Expecter{mock: &_m.Mock} +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *PlanService) GetByID(ctx context.Context, id string) (plan.Plan, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 plan.Plan + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (plan.Plan, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) plan.Plan); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(plan.Plan) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PlanService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type PlanService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *PlanService_Expecter) GetByID(ctx interface{}, id interface{}) *PlanService_GetByID_Call { + return &PlanService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *PlanService_GetByID_Call) Run(run func(ctx context.Context, id string)) *PlanService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *PlanService_GetByID_Call) Return(_a0 plan.Plan, _a1 error) *PlanService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PlanService_GetByID_Call) RunAndReturn(run func(context.Context, string) (plan.Plan, error)) *PlanService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, filter +func (_m *PlanService) List(ctx context.Context, filter plan.Filter) ([]plan.Plan, error) { + ret := _m.Called(ctx, filter) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []plan.Plan + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, plan.Filter) ([]plan.Plan, error)); ok { + return rf(ctx, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, plan.Filter) []plan.Plan); ok { + r0 = rf(ctx, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]plan.Plan) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, plan.Filter) error); ok { + r1 = rf(ctx, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PlanService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type PlanService_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - filter plan.Filter +func (_e *PlanService_Expecter) List(ctx interface{}, filter interface{}) *PlanService_List_Call { + return &PlanService_List_Call{Call: _e.mock.On("List", ctx, filter)} +} + +func (_c *PlanService_List_Call) Run(run func(ctx context.Context, filter plan.Filter)) *PlanService_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(plan.Filter)) + }) + return _c +} + +func (_c *PlanService_List_Call) Return(_a0 []plan.Plan, _a1 error) *PlanService_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PlanService_List_Call) RunAndReturn(run func(context.Context, plan.Filter) ([]plan.Plan, error)) *PlanService_List_Call { + _c.Call.Return(run) + return _c +} + +// NewPlanService creates a new instance of PlanService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPlanService(t interface { + mock.TestingT + Cleanup(func()) +}) *PlanService { + mock := &PlanService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/mocks/product_service.go b/billing/subscription/mocks/product_service.go new file mode 100644 index 000000000..d62dca0da --- /dev/null +++ b/billing/subscription/mocks/product_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + product "github.com/raystack/frontier/billing/product" + mock "github.com/stretchr/testify/mock" +) + +// ProductService is an autogenerated mock type for the ProductService type +type ProductService struct { + mock.Mock +} + +type ProductService_Expecter struct { + mock *mock.Mock +} + +func (_m *ProductService) EXPECT() *ProductService_Expecter { + return &ProductService_Expecter{mock: &_m.Mock} +} + +// GetByProviderID provides a mock function with given fields: ctx, id +func (_m *ProductService) GetByProviderID(ctx context.Context, id string) (product.Product, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByProviderID") + } + + var r0 product.Product + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (product.Product, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) product.Product); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(product.Product) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ProductService_GetByProviderID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByProviderID' +type ProductService_GetByProviderID_Call struct { + *mock.Call +} + +// GetByProviderID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *ProductService_Expecter) GetByProviderID(ctx interface{}, id interface{}) *ProductService_GetByProviderID_Call { + return &ProductService_GetByProviderID_Call{Call: _e.mock.On("GetByProviderID", ctx, id)} +} + +func (_c *ProductService_GetByProviderID_Call) Run(run func(ctx context.Context, id string)) *ProductService_GetByProviderID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *ProductService_GetByProviderID_Call) Return(_a0 product.Product, _a1 error) *ProductService_GetByProviderID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ProductService_GetByProviderID_Call) RunAndReturn(run func(context.Context, string) (product.Product, error)) *ProductService_GetByProviderID_Call { + _c.Call.Return(run) + return _c +} + +// NewProductService creates a new instance of ProductService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewProductService(t interface { + mock.TestingT + Cleanup(func()) +}) *ProductService { + mock := &ProductService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/mocks/repository.go b/billing/subscription/mocks/repository.go new file mode 100644 index 000000000..2dc69c965 --- /dev/null +++ b/billing/subscription/mocks/repository.go @@ -0,0 +1,371 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + subscription "github.com/raystack/frontier/billing/subscription" + mock "github.com/stretchr/testify/mock" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, subs +func (_m *Repository) Create(ctx context.Context, subs subscription.Subscription) (subscription.Subscription, error) { + ret := _m.Called(ctx, subs) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 subscription.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, subscription.Subscription) (subscription.Subscription, error)); ok { + return rf(ctx, subs) + } + if rf, ok := ret.Get(0).(func(context.Context, subscription.Subscription) subscription.Subscription); ok { + r0 = rf(ctx, subs) + } else { + r0 = ret.Get(0).(subscription.Subscription) + } + + if rf, ok := ret.Get(1).(func(context.Context, subscription.Subscription) error); ok { + r1 = rf(ctx, subs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type Repository_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - subs subscription.Subscription +func (_e *Repository_Expecter) Create(ctx interface{}, subs interface{}) *Repository_Create_Call { + return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, subs)} +} + +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, subs subscription.Subscription)) *Repository_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(subscription.Subscription)) + }) + return _c +} + +func (_c *Repository_Create_Call) Return(_a0 subscription.Subscription, _a1 error) *Repository_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, subscription.Subscription) (subscription.Subscription, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *Repository) Delete(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Repository_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) Delete(ctx interface{}, id interface{}) *Repository_Delete_Call { + return &Repository_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *Repository_Delete_Call) Run(run func(ctx context.Context, id string)) *Repository_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_Delete_Call) Return(_a0 error) *Repository_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, string) error) *Repository_Delete_Call { + _c.Call.Return(run) + return _c +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByID(ctx context.Context, id string) (subscription.Subscription, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 subscription.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (subscription.Subscription, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) subscription.Subscription); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(subscription.Subscription) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type Repository_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByID(ctx interface{}, id interface{}) *Repository_GetByID_Call { + return &Repository_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *Repository_GetByID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByID_Call) Return(_a0 subscription.Subscription, _a1 error) *Repository_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByID_Call) RunAndReturn(run func(context.Context, string) (subscription.Subscription, error)) *Repository_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// GetByProviderID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByProviderID(ctx context.Context, id string) (subscription.Subscription, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByProviderID") + } + + var r0 subscription.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (subscription.Subscription, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) subscription.Subscription); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(subscription.Subscription) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByProviderID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByProviderID' +type Repository_GetByProviderID_Call struct { + *mock.Call +} + +// GetByProviderID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByProviderID(ctx interface{}, id interface{}) *Repository_GetByProviderID_Call { + return &Repository_GetByProviderID_Call{Call: _e.mock.On("GetByProviderID", ctx, id)} +} + +func (_c *Repository_GetByProviderID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByProviderID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByProviderID_Call) Return(_a0 subscription.Subscription, _a1 error) *Repository_GetByProviderID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByProviderID_Call) RunAndReturn(run func(context.Context, string) (subscription.Subscription, error)) *Repository_GetByProviderID_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, filter +func (_m *Repository) List(ctx context.Context, filter subscription.Filter) ([]subscription.Subscription, error) { + ret := _m.Called(ctx, filter) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []subscription.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, subscription.Filter) ([]subscription.Subscription, error)); ok { + return rf(ctx, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, subscription.Filter) []subscription.Subscription); ok { + r0 = rf(ctx, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]subscription.Subscription) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, subscription.Filter) error); ok { + r1 = rf(ctx, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Repository_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - filter subscription.Filter +func (_e *Repository_Expecter) List(ctx interface{}, filter interface{}) *Repository_List_Call { + return &Repository_List_Call{Call: _e.mock.On("List", ctx, filter)} +} + +func (_c *Repository_List_Call) Run(run func(ctx context.Context, filter subscription.Filter)) *Repository_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(subscription.Filter)) + }) + return _c +} + +func (_c *Repository_List_Call) Return(_a0 []subscription.Subscription, _a1 error) *Repository_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_List_Call) RunAndReturn(run func(context.Context, subscription.Filter) ([]subscription.Subscription, error)) *Repository_List_Call { + _c.Call.Return(run) + return _c +} + +// UpdateByID provides a mock function with given fields: ctx, subs +func (_m *Repository) UpdateByID(ctx context.Context, subs subscription.Subscription) (subscription.Subscription, error) { + ret := _m.Called(ctx, subs) + + if len(ret) == 0 { + panic("no return value specified for UpdateByID") + } + + var r0 subscription.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, subscription.Subscription) (subscription.Subscription, error)); ok { + return rf(ctx, subs) + } + if rf, ok := ret.Get(0).(func(context.Context, subscription.Subscription) subscription.Subscription); ok { + r0 = rf(ctx, subs) + } else { + r0 = ret.Get(0).(subscription.Subscription) + } + + if rf, ok := ret.Get(1).(func(context.Context, subscription.Subscription) error); ok { + r1 = rf(ctx, subs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_UpdateByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateByID' +type Repository_UpdateByID_Call struct { + *mock.Call +} + +// UpdateByID is a helper method to define mock.On call +// - ctx context.Context +// - subs subscription.Subscription +func (_e *Repository_Expecter) UpdateByID(ctx interface{}, subs interface{}) *Repository_UpdateByID_Call { + return &Repository_UpdateByID_Call{Call: _e.mock.On("UpdateByID", ctx, subs)} +} + +func (_c *Repository_UpdateByID_Call) Run(run func(ctx context.Context, subs subscription.Subscription)) *Repository_UpdateByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(subscription.Subscription)) + }) + return _c +} + +func (_c *Repository_UpdateByID_Call) Return(_a0 subscription.Subscription, _a1 error) *Repository_UpdateByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_UpdateByID_Call) RunAndReturn(run func(context.Context, subscription.Subscription) (subscription.Subscription, error)) *Repository_UpdateByID_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/billing/subscription/service_test.go b/billing/subscription/service_test.go new file mode 100644 index 000000000..4567e3b04 --- /dev/null +++ b/billing/subscription/service_test.go @@ -0,0 +1,815 @@ +package subscription_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/raystack/frontier/billing/product" + stripemock "github.com/raystack/frontier/billing/stripetest/mocks" + + "github.com/raystack/frontier/billing" + "github.com/raystack/frontier/billing/customer" + "github.com/raystack/frontier/billing/plan" + "github.com/raystack/frontier/billing/subscription" + "github.com/raystack/frontier/billing/subscription/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stripe/stripe-go/v79" + "github.com/stripe/stripe-go/v79/client" +) + +func TestService_GetByID(t *testing.T) { + tests := []struct { + name string + id string + setup func(*mocks.Repository) + want subscription.Subscription + wantErr error + }{ + { + name: "should return subscription if found", + id: "test-id", + setup: func(r *mocks.Repository) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{ + ID: "test-id", + PlanID: "plan-1", + State: "active", + }, nil) + }, + want: subscription.Subscription{ + ID: "test-id", + PlanID: "plan-1", + State: "active", + }, + }, + { + name: "should return error if not found", + id: "test-id", + setup: func(r *mocks.Repository) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{}, subscription.ErrNotFound) + }, + wantErr: subscription.ErrNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + if tt.setup != nil { + tt.setup(mockRepo) + } + + svc := subscription.NewService(nil, billing.Config{}, mockRepo, nil, nil, nil, nil, nil) + got, err := svc.GetByID(context.Background(), tt.id) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_Cancel(t *testing.T) { + tests := []struct { + name string + id string + immediate bool + setup func(*mocks.Repository, *stripe.Subscription, *stripe.SubscriptionSchedule) + wantErr error + }{ + { + name: "should return error if subscription not found", + id: "test-id", + immediate: true, + setup: func(r *mocks.Repository, _ *stripe.Subscription, _ *stripe.SubscriptionSchedule) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{}, subscription.ErrNotFound) + }, + wantErr: subscription.ErrNotFound, + }, + { + name: "should cancel subscription immediately", + id: "test-id", + immediate: true, + setup: func(r *mocks.Repository, stripeSub *stripe.Subscription, _ *stripe.SubscriptionSchedule) { + // Setup repository expectations + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{ + ID: "test-id", + State: subscription.StateActive.String(), + ProviderID: "stripe-sub-id", + }, nil) + + // Setup stripe subscription response + *stripeSub = stripe.Subscription{ + ID: "stripe-sub-id", + Status: stripe.SubscriptionStatusCanceled, + CanceledAt: time.Now().Unix(), + } + + r.EXPECT().UpdateByID(mock.Anything, mock.MatchedBy(func(s subscription.Subscription) bool { + return s.ID == "test-id" && s.State == subscription.StateCanceled.String() + })).Return(subscription.Subscription{}, nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockBackend := stripemock.NewBackend(t) + + // Create stripe client with mock backend + stripeClient := client.New("key_123", &stripe.Backends{ + API: mockBackend, + }) + + stripeSub := &stripe.Subscription{} + stripeSched := &stripe.SubscriptionSchedule{} + + if tt.setup != nil { + tt.setup(mockRepo, stripeSub, stripeSched) + } + + // Setup mock backend expectations + if stripeSub.ID != "" { + mockBackend.EXPECT().Call("GET", "/v1/subscriptions/"+stripeSub.ID, "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sub := v.(*stripe.Subscription) + sub.ID = stripeSub.ID + sub.Status = stripe.SubscriptionStatusActive + }).Return(nil) + + mockBackend.EXPECT().Call("POST", "/v1/subscription_schedules", "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sched := v.(*stripe.SubscriptionSchedule) + sched.ID = "sched_123" + }).Return(nil) + + mockBackend.EXPECT().Call("DELETE", "/v1/subscriptions/"+stripeSub.ID, "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sub := v.(*stripe.Subscription) + sub.ID = stripeSub.ID + sub.Status = stripe.SubscriptionStatusCanceled + sub.CanceledAt = time.Now().Unix() + }).Return(nil) + } + + svc := subscription.NewService(stripeClient, billing.Config{}, mockRepo, nil, nil, nil, nil, nil) + _, err := svc.Cancel(context.Background(), tt.id, tt.immediate) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestService_ChangePlan(t *testing.T) { + tests := []struct { + name string + id string + change subscription.ChangeRequest + setup func(*mocks.Repository, *mocks.PlanService, *mocks.CustomerService, *mocks.OrganizationService) + want subscription.Phase + wantErr error + }{ + { + name: "should return error if subscription not found", + id: "test-id", + change: subscription.ChangeRequest{ + PlanID: "new-plan", + Immediate: true, + }, + setup: func(r *mocks.Repository, p *mocks.PlanService, c *mocks.CustomerService, o *mocks.OrganizationService) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{}, subscription.ErrNotFound) + }, + wantErr: subscription.ErrNotFound, + }, + { + name: "should return error if subscription not active", + id: "test-id", + change: subscription.ChangeRequest{ + PlanID: "new-plan", + Immediate: true, + }, + setup: func(r *mocks.Repository, p *mocks.PlanService, c *mocks.CustomerService, o *mocks.OrganizationService) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{ + ID: "test-id", + State: subscription.StateCanceled.String(), + }, nil) + }, + wantErr: errors.New("only active subscriptions can be changed"), + }, + { + name: "should return error if plan not found", + id: "test-id", + change: subscription.ChangeRequest{ + PlanID: "new-plan", + Immediate: true, + }, + setup: func(r *mocks.Repository, p *mocks.PlanService, c *mocks.CustomerService, o *mocks.OrganizationService) { + r.EXPECT().GetByID(mock.Anything, "test-id").Return(subscription.Subscription{ + ID: "test-id", + State: subscription.StateActive.String(), + }, nil) + p.EXPECT().GetByID(mock.Anything, "new-plan").Return(plan.Plan{}, errors.New("plan not found")) + }, + wantErr: errors.New("plan not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockPlanSvc := mocks.NewPlanService(t) + mockCustomerSvc := mocks.NewCustomerService(t) + mockOrgSvc := mocks.NewOrganizationService(t) + + if tt.setup != nil { + tt.setup(mockRepo, mockPlanSvc, mockCustomerSvc, mockOrgSvc) + } + + svc := subscription.NewService(nil, billing.Config{}, mockRepo, mockCustomerSvc, mockPlanSvc, mockOrgSvc, nil, nil) + got, err := svc.ChangePlan(context.Background(), tt.id, tt.change) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_SyncWithProvider(t *testing.T) { + tests := []struct { + name string + cust customer.Customer + setup func(*mocks.Repository, *mocks.CustomerService, *mocks.PlanService, *mocks.OrganizationService, *mocks.ProductService, *stripe.Subscription) + wantErr error + }{ + { + name: "should sync active subscriptions", + cust: customer.Customer{ + ID: "customer-1", + OrgID: "org-1", + State: customer.ActiveState, + }, + setup: func(r *mocks.Repository, c *mocks.CustomerService, p *mocks.PlanService, o *mocks.OrganizationService, prodSvc *mocks.ProductService, stripeSub *stripe.Subscription) { + // Setup stripe subscription first since it affects the flow + *stripeSub = stripe.Subscription{ + ID: "stripe-sub-1", + Status: stripe.SubscriptionStatusActive, + Items: &stripe.SubscriptionItemList{ + Data: []*stripe.SubscriptionItem{ + { + ID: "si_123", + Price: &stripe.Price{ + ID: "price_123", + Recurring: &stripe.PriceRecurring{ + Interval: stripe.PriceRecurringIntervalMonth, + }, + Product: &stripe.Product{ + ID: "prod_123", + }, + }, + Quantity: 1, + Metadata: map[string]string{ + "price_id": "price-1", + "managed_by": "frontier", + }, + }, + }, + }, + Schedule: &stripe.SubscriptionSchedule{ + ID: "sched_123", + CurrentPhase: &stripe.SubscriptionScheduleCurrentPhase{ + StartDate: time.Now().Unix(), + EndDate: time.Now().Add(24 * time.Hour).Unix(), + }, + Phases: []*stripe.SubscriptionSchedulePhase{ + { + StartDate: time.Now().Unix(), + EndDate: time.Now().Add(24 * time.Hour).Unix(), + Items: []*stripe.SubscriptionSchedulePhaseItem{ + { + Price: &stripe.Price{ + ID: "price_123", + Recurring: &stripe.PriceRecurring{ + Interval: stripe.PriceRecurringIntervalMonth, + }, + Product: &stripe.Product{ + ID: "prod_123", + }, + }, + Quantity: 1, + }, + }, + Metadata: map[string]string{ + "plan_id": "plan-1", + "managed_by": "frontier", + }, + }, + { + StartDate: time.Now().Add(24 * time.Hour).Unix(), + EndDate: time.Now().Add(48 * time.Hour).Unix(), + Items: []*stripe.SubscriptionSchedulePhaseItem{ + { + Price: &stripe.Price{ + ID: "price_123", + Recurring: &stripe.PriceRecurring{ + Interval: stripe.PriceRecurringIntervalMonth, + }, + Product: &stripe.Product{ + ID: "prod_123", + }, + }, + Quantity: 1, + }, + }, + Metadata: map[string]string{ + "plan_id": "plan-1", + "managed_by": "frontier", + }, + }, + }, + }, + } + + // 1. List subscriptions + r.EXPECT().List(mock.Anything, mock.MatchedBy(func(f subscription.Filter) bool { + return f.CustomerID == "customer-1" + })).Return([]subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + ProviderID: "stripe-sub-1", + }, + }, nil) + + // 2. Update subscription state from stripe + r.EXPECT().UpdateByID(mock.Anything, mock.MatchedBy(func(s subscription.Subscription) bool { + return s.ID == "sub-1" && s.State == subscription.StateActive.String() + })).Return(subscription.Subscription{ + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + ProviderID: "stripe-sub-1", + }, nil) + + // 3. Get plan details for active subscription + p.EXPECT().GetByID(mock.Anything, "plan-1").Return(plan.Plan{ + ID: "plan-1", + Products: []product.Product{ + { + ID: "product-1", + Behavior: product.PerSeatBehavior, + Prices: []product.Price{ + { + ID: "price-1", + ProviderID: "price_123", + Interval: string(stripe.PriceRecurringIntervalMonth), + }, + }, + }, + }, + Metadata: map[string]interface{}{ + "price_id": "price_123", + }, + }, nil).Times(2) // Called for both current and next phase + + // 4. Get member count for quantity update + o.EXPECT().MemberCount(mock.Anything, "org-1").Return(int64(2), nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockCustomerSvc := mocks.NewCustomerService(t) + mockPlanSvc := mocks.NewPlanService(t) + mockOrgSvc := mocks.NewOrganizationService(t) + mockProdSvc := mocks.NewProductService(t) + mockBackend := stripemock.NewBackend(t) + + stripeClient := client.New("key_123", &stripe.Backends{ + API: mockBackend, + }) + + stripeSub := &stripe.Subscription{} + + if tt.setup != nil { + tt.setup(mockRepo, mockCustomerSvc, mockPlanSvc, mockOrgSvc, mockProdSvc, stripeSub) + } + + if stripeSub.ID != "" { + // Mock GET subscription call + mockBackend.EXPECT().Call("GET", "/v1/subscriptions/"+stripeSub.ID, "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sub := v.(*stripe.Subscription) + *sub = *stripeSub + }).Return(nil).Once() + + // Mock GET schedule call + mockBackend.EXPECT().Call("GET", "/v1/subscription_schedules/sched_123", "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sched := v.(*stripe.SubscriptionSchedule) + *sched = *stripeSub.Schedule + }).Return(nil).Once() + + // Mock POST subscription update call for quantity + mockBackend.EXPECT().Call("POST", "/v1/subscriptions/"+stripeSub.ID, "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sub := v.(*stripe.Subscription) + *sub = *stripeSub + sub.Items.Data[0].Quantity = 2 // Updated quantity + }).Return(nil).Once() + + // Mock POST subscription schedule update call for both phases + mockBackend.EXPECT().Call("POST", "/v1/subscription_schedules/sched_123", "key_123", + mock.Anything, mock.Anything).Run(func(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) { + sched := v.(*stripe.SubscriptionSchedule) + *sched = *stripeSub.Schedule + sched.Phases[0].Items[0].Quantity = 2 // Updated quantity for current phase + sched.Phases[1].Items[0].Quantity = 2 // Updated quantity for next phase + }).Return(nil).Once() + } + + svc := subscription.NewService( + stripeClient, + billing.Config{ + ProductConfig: billing.ProductConfig{ + SeatChangeBehavior: "exact", + }, + }, + mockRepo, + mockCustomerSvc, + mockPlanSvc, + mockOrgSvc, + mockProdSvc, + nil, + ) + + err := svc.SyncWithProvider(context.Background(), tt.cust) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestService_HasUserSubscribedBefore(t *testing.T) { + tests := []struct { + name string + customerID string + planID string + setup func(*mocks.Repository) + want bool + wantErr error + }{ + { + name: "should return true if user has active subscription", + customerID: "customer-1", + planID: "plan-1", + setup: func(r *mocks.Repository) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return([]subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + }, + }, nil) + }, + want: true, + }, + { + name: "should return true if user had subscription in history", + customerID: "customer-1", + planID: "plan-1", + setup: func(r *mocks.Repository) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return([]subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-2", + State: subscription.StateActive.String(), + PlanHistory: []subscription.Phase{ + { + PlanID: "plan-1", + }, + }, + }, + }, nil) + }, + want: true, + }, + { + name: "should return false if user never subscribed", + customerID: "customer-1", + planID: "plan-1", + setup: func(r *mocks.Repository) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return([]subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-2", + State: subscription.StateActive.String(), + }, + }, nil) + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + if tt.setup != nil { + tt.setup(mockRepo) + } + + svc := subscription.NewService(nil, billing.Config{}, mockRepo, nil, nil, nil, nil, nil) + got, err := svc.HasUserSubscribedBefore(context.Background(), tt.customerID, tt.planID) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_Create(t *testing.T) { + tests := []struct { + name string + sub subscription.Subscription + setup func(*mocks.Repository) + want subscription.Subscription + wantErr error + }{ + { + name: "should create new subscription", + sub: subscription.Subscription{ + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + Metadata: map[string]interface{}{ + "test": "data", + }, + }, + setup: func(r *mocks.Repository) { + r.EXPECT().Create(mock.Anything, mock.MatchedBy(func(s subscription.Subscription) bool { + return s.CustomerID == "customer-1" && s.PlanID == "plan-1" + })).Return(subscription.Subscription{ + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + Metadata: map[string]interface{}{ + "test": "data", + }, + }, nil) + }, + want: subscription.Subscription{ + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + Metadata: map[string]interface{}{ + "test": "data", + }, + }, + }, + { + name: "should return error if repository fails", + sub: subscription.Subscription{ + CustomerID: "customer-1", + PlanID: "plan-1", + }, + setup: func(r *mocks.Repository) { + r.EXPECT().Create(mock.Anything, mock.Anything).Return(subscription.Subscription{}, errors.New("db error")) + }, + wantErr: errors.New("db error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + if tt.setup != nil { + tt.setup(mockRepo) + } + + svc := subscription.NewService(nil, billing.Config{}, mockRepo, nil, nil, nil, nil, nil) + got, err := svc.Create(context.Background(), tt.sub) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_List(t *testing.T) { + tests := []struct { + name string + filter subscription.Filter + setup func(*mocks.Repository) + want []subscription.Subscription + wantErr error + }{ + { + name: "should list all subscriptions", + filter: subscription.Filter{ + CustomerID: "customer-1", + }, + setup: func(r *mocks.Repository) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return([]subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + }, + { + ID: "sub-2", + CustomerID: "customer-1", + PlanID: "plan-2", + State: subscription.StateCanceled.String(), + }, + }, nil) + }, + want: []subscription.Subscription{ + { + ID: "sub-1", + CustomerID: "customer-1", + PlanID: "plan-1", + State: subscription.StateActive.String(), + }, + { + ID: "sub-2", + CustomerID: "customer-1", + PlanID: "plan-2", + State: subscription.StateCanceled.String(), + }, + }, + }, + { + name: "should return error if repository fails", + filter: subscription.Filter{ + CustomerID: "customer-1", + }, + setup: func(r *mocks.Repository) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return(nil, errors.New("db error")) + }, + wantErr: errors.New("db error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + if tt.setup != nil { + tt.setup(mockRepo) + } + + svc := subscription.NewService(nil, billing.Config{}, mockRepo, nil, nil, nil, nil, nil) + got, err := svc.List(context.Background(), tt.filter) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestService_DeleteByCustomer(t *testing.T) { + tests := []struct { + name string + cust customer.Customer + setup func(*mocks.Repository, *stripemock.Backend, *mocks.PlanService, *mocks.ProductService) + wantErr error + }{ + { + name: "should return error if listing subscriptions fails", + cust: customer.Customer{ + ID: "customer-1", + State: customer.ActiveState, + }, + setup: func(r *mocks.Repository, b *stripemock.Backend, p *mocks.PlanService, ps *mocks.ProductService) { + r.EXPECT().List(mock.Anything, subscription.Filter{ + CustomerID: "customer-1", + }).Return(nil, errors.New("db error")) + }, + wantErr: errors.New("db error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockBackend := stripemock.NewBackend(t) + mockPlanSvc := mocks.NewPlanService(t) + mockProdSvc := mocks.NewProductService(t) + + // Create stripe client with mock backend + stripeClient := client.New("key_123", &stripe.Backends{ + API: mockBackend, + }) + + if tt.setup != nil { + tt.setup(mockRepo, mockBackend, mockPlanSvc, mockProdSvc) + } + + svc := subscription.NewService(stripeClient, billing.Config{}, mockRepo, nil, mockPlanSvc, nil, mockProdSvc, nil) + err := svc.DeleteByCustomer(context.Background(), tt.cust) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestService_Init(t *testing.T) { + tests := []struct { + name string + config billing.Config + wantErr error + }{ + { + name: "should initialize service with cron job", + config: billing.Config{ + RefreshInterval: billing.RefreshInterval{ + Subscription: time.Minute, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := subscription.NewService(nil, tt.config, nil, nil, nil, nil, nil, nil) + err := svc.Init(context.Background()) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + return + } + + assert.NoError(t, err) + assert.NoError(t, svc.Close()) + }) + } +} From 7655c96d183458563ba3467e4ff6bf3b88605099 Mon Sep 17 00:00:00 2001 From: Kush Sharma Date: Sun, 15 Dec 2024 14:18:29 +0530 Subject: [PATCH 2/2] refactor sync method Signed-off-by: Kush Sharma --- billing/subscription/service.go | 243 ++++++++++++++++---------------- 1 file changed, 123 insertions(+), 120 deletions(-) diff --git a/billing/subscription/service.go b/billing/subscription/service.go index 8fb0ef95d..209ad90cc 100644 --- a/billing/subscription/service.go +++ b/billing/subscription/service.go @@ -203,7 +203,6 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome var subErrs []error for _, sub := range subs { if ctx.Err() != nil { - // stop processing if context is done break } @@ -211,137 +210,57 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome continue } - stripeSubscription, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) - if err != nil { - if errors.Is(err, ErrSubscriptionOnProviderNotFound) { - // if it's a test resource, mark it as canceled - if val, ok := sub.Metadata[ProviderTestResource].(bool); ok && val { - sub.State = StateCanceled.String() - sub.CanceledAt = time.Now().UTC() - if _, err := s.repository.UpdateByID(ctx, sub); err != nil { - subErrs = append(subErrs, err) - } - } else { - subErrs = append(subErrs, fmt.Errorf("%s: %w", sub.ID, err)) - } - } else { - subErrs = append(subErrs, err) - } - continue + if err := s.syncSubscription(ctx, sub, customr); err != nil { + subErrs = append(subErrs, fmt.Errorf("failed to sync subscription %s: %w", sub.ID, err)) } + } - updateNeeded := false - if sub.State != string(stripeSubscription.Status) { - updateNeeded = true - sub.State = string(stripeSubscription.Status) - } - if stripeSubscription.CanceledAt > 0 && sub.CanceledAt.Unix() != stripeSubscription.CanceledAt { - updateNeeded = true - sub.CanceledAt = utils.AsTimeFromEpoch(stripeSubscription.CanceledAt) - } - if stripeSubscription.EndedAt > 0 && sub.EndedAt.Unix() != stripeSubscription.EndedAt { - updateNeeded = true - sub.EndedAt = utils.AsTimeFromEpoch(stripeSubscription.EndedAt) - } - if stripeSubscription.TrialEnd > 0 && sub.TrialEndsAt.Unix() != stripeSubscription.TrialEnd { - updateNeeded = true - sub.TrialEndsAt = utils.AsTimeFromEpoch(stripeSubscription.TrialEnd) - } - if stripeSubscription.CurrentPeriodStart > 0 && sub.CurrentPeriodStartAt.Unix() != stripeSubscription.CurrentPeriodStart { - updateNeeded = true - sub.CurrentPeriodStartAt = utils.AsTimeFromEpoch(stripeSubscription.CurrentPeriodStart) - } - if stripeSubscription.CurrentPeriodEnd > 0 && sub.CurrentPeriodEndAt.Unix() != stripeSubscription.CurrentPeriodEnd { - updateNeeded = true - sub.CurrentPeriodEndAt = utils.AsTimeFromEpoch(stripeSubscription.CurrentPeriodEnd) - } - if stripeSubscription.BillingCycleAnchor > 0 && sub.BillingCycleAnchorAt.Unix() != stripeSubscription.BillingCycleAnchor { - updateNeeded = true - sub.BillingCycleAnchorAt = utils.AsTimeFromEpoch(stripeSubscription.BillingCycleAnchor) - } - - // update plan id if it's changed - currentPlanID, nextPlanID, err := s.getPlanFromSchedule(ctx, stripeSchedule) - if errors.Is(err, ErrNoPhaseActive) { - currentPlan, err := s.findPlanByStripeSubscription(ctx, stripeSubscription) - if err != nil { - subErrs = append(subErrs, fmt.Errorf("failed to find plan from stripe subscription: %w", err)) - continue - } - currentPlanID = currentPlan.ID - } else if err != nil { - subErrs = append(subErrs, fmt.Errorf("failed to find plan from stripe schedule: %w", err)) - continue - } - - if sub.PlanID != currentPlanID { - sub.PlanID = currentPlanID - - // update plan history - if sub.PlanID != "" { - sub.PlanHistory = append(sub.PlanHistory, Phase{ - EndsAt: time.Now().UTC(), - PlanID: sub.PlanID, - }) - } - updateNeeded = true - } - - // update phase if it's changed - if sub.Phase.PlanID != nextPlanID { - sub.Phase.PlanID = nextPlanID - sub.Phase.Reason = SubscriptionChange.String() - - if stripeSchedule != nil && stripeSchedule.EndBehavior == stripe.SubscriptionScheduleEndBehaviorCancel { - sub.Phase.Reason = SubscriptionCancel.String() - } - - updateNeeded = true - } - if stripeSubscription.Schedule != nil { - if stripeSubscription.Schedule.CurrentPhase == nil && - sub.Phase.EffectiveAt.Unix() > 0 { - sub.Phase.EffectiveAt = time.Time{} - updateNeeded = true - } - if stripeSubscription.Schedule.CurrentPhase != nil && - sub.Phase.EffectiveAt.Unix() != stripeSubscription.Schedule.CurrentPhase.EndDate { - sub.Phase.EffectiveAt = utils.AsTimeFromEpoch(stripeSubscription.Schedule.CurrentPhase.EndDate) - updateNeeded = true - } - } + if len(subErrs) > 0 { + return fmt.Errorf("failed to sync subscriptions: %w", errors.Join(subErrs...)) + } + return nil +} - // update sub change if it's changed - if updateNeeded { - if _, err := s.repository.UpdateByID(ctx, sub); err != nil { +// syncSubscription handles syncing a single subscription with the provider +func (s *Service) syncSubscription(ctx context.Context, sub Subscription, customr customer.Customer) error { + stripeSubscription, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) + if err != nil { + if errors.Is(err, ErrSubscriptionOnProviderNotFound) { + // if it's a test resource, mark it as canceled + if val, ok := sub.Metadata[ProviderTestResource].(bool); ok && val { + sub.State = StateCanceled.String() + sub.CanceledAt = time.Now().UTC() + _, err := s.repository.UpdateByID(ctx, sub) return err } + return fmt.Errorf("%s: %w", sub.ID, err) } + return err + } - // TODO: We are getting an empty planID here, because the plan ID is being incorrectly set as empty in cancel scenarios of free trial. - // The check of sub.PlanID != "" is a temporary one. We need to understand why the next phase's plan id is coming up as empty. - if sub.IsActive() && sub.PlanID != "" { - subPlan, err := s.planService.GetByID(ctx, sub.PlanID) - if err != nil { - return fmt.Errorf("%w: subscription: %s plan: %s", err, sub.ID, sub.PlanID) - } + if updated, err := s.syncSubscriptionState(ctx, sub, stripeSubscription, stripeSchedule); err != nil { + return err + } else if !updated.IsActive() || sub.PlanID == "" { + return nil + } - // per seat pricing is enabled, update the quantity - if err = s.UpdateProductQuantity(ctx, customr.OrgID, subPlan, - stripeSubscription, stripeSchedule); err != nil { - return fmt.Errorf("failed to update product quantity: %w", err) - } + // Get current plan + subPlan, err := s.planService.GetByID(ctx, sub.PlanID) + if err != nil { + return fmt.Errorf("%w: subscription: %s plan: %s", err, sub.ID, sub.PlanID) + } - // subscription can also be complimented with free credits - if err := s.ensureCreditsForPlan(ctx, sub, subPlan); err != nil { - return fmt.Errorf("ensureCreditsForPlan: %w", err) - } - } + // Update product quantity if needed + if err = s.UpdateProductQuantity(ctx, customr.OrgID, subPlan, + stripeSubscription, stripeSchedule); err != nil { + return fmt.Errorf("failed to update product quantity: %w", err) } - if len(subErrs) > 0 { - return fmt.Errorf("failed to sync subscriptions: %w", errors.Join(subErrs...)) + // Ensure credits for plan + if err := s.ensureCreditsForPlan(ctx, sub, subPlan); err != nil { + return fmt.Errorf("ensureCreditsForPlan: %w", err) } + return nil } @@ -1179,3 +1098,87 @@ func (s *Service) HasUserSubscribedBefore(ctx context.Context, customerID string } return false, nil } + +// syncSubscriptionState syncs the subscription state with the provider and returns the updated subscription +func (s *Service) syncSubscriptionState(ctx context.Context, sub Subscription, + stripeSubscription *stripe.Subscription, + stripeSchedule *stripe.SubscriptionSchedule) (Subscription, error) { + updateNeeded := false + + // Sync basic subscription state + if sub.State != string(stripeSubscription.Status) { + updateNeeded = true + sub.State = string(stripeSubscription.Status) + } + + // Sync timestamps + timestamps := []struct { + current *time.Time + new int64 + }{ + {&sub.CanceledAt, stripeSubscription.CanceledAt}, + {&sub.EndedAt, stripeSubscription.EndedAt}, + {&sub.TrialEndsAt, stripeSubscription.TrialEnd}, + {&sub.CurrentPeriodStartAt, stripeSubscription.CurrentPeriodStart}, + {&sub.CurrentPeriodEndAt, stripeSubscription.CurrentPeriodEnd}, + {&sub.BillingCycleAnchorAt, stripeSubscription.BillingCycleAnchor}, + } + + for _, ts := range timestamps { + if ts.new > 0 && ts.current.Unix() != ts.new { + updateNeeded = true + *ts.current = utils.AsTimeFromEpoch(ts.new) + } + } + + // Update plan IDs + currentPlanID, nextPlanID, err := s.getPlanFromSchedule(ctx, stripeSchedule) + if errors.Is(err, ErrNoPhaseActive) { + currentPlan, err := s.findPlanByStripeSubscription(ctx, stripeSubscription) + if err != nil { + return sub, fmt.Errorf("failed to find plan from stripe subscription: %w", err) + } + currentPlanID = currentPlan.ID + } else if err != nil { + return sub, fmt.Errorf("failed to find plan from stripe schedule: %w", err) + } + + if sub.PlanID != currentPlanID { + updateNeeded = true + if sub.PlanID != "" { + sub.PlanHistory = append(sub.PlanHistory, Phase{ + EndsAt: time.Now().UTC(), + PlanID: sub.PlanID, + }) + } + sub.PlanID = currentPlanID + } + + // Update phase + if sub.Phase.PlanID != nextPlanID { + updateNeeded = true + sub.Phase.PlanID = nextPlanID + sub.Phase.Reason = SubscriptionChange.String() + + if stripeSchedule != nil && stripeSchedule.EndBehavior == stripe.SubscriptionScheduleEndBehaviorCancel { + sub.Phase.Reason = SubscriptionCancel.String() + } + } + + // Update phase effective date + if stripeSubscription.Schedule != nil { + if stripeSubscription.Schedule.CurrentPhase == nil && sub.Phase.EffectiveAt.Unix() > 0 { + updateNeeded = true + sub.Phase.EffectiveAt = time.Time{} + } else if stripeSubscription.Schedule.CurrentPhase != nil && + sub.Phase.EffectiveAt.Unix() != stripeSubscription.Schedule.CurrentPhase.EndDate { + updateNeeded = true + sub.Phase.EffectiveAt = utils.AsTimeFromEpoch(stripeSubscription.Schedule.CurrentPhase.EndDate) + } + } + + if updateNeeded { + return s.repository.UpdateByID(ctx, sub) + } + return sub, nil +}