Skip to content

Commit e8c9056

Browse files
authored
Merge pull request #155 from mraerino/fix/calculation-discount-taxes
Fix calculation of taxes when items are discounted
2 parents c29af21 + cac1d86 commit e8c9056

File tree

4 files changed

+286
-109
lines changed

4 files changed

+286
-109
lines changed

api/order_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func TestOrderCreate(t *testing.T) {
159159
assert.Equal(t, "Germany", order.ShippingAddress.Country)
160160
assert.Equal(t, "Germany", order.BillingAddress.Country)
161161
assert.Equal(t, total, order.Total, fmt.Sprintf("Total should be 1105, was %v", order.Total))
162-
assert.Equal(t, taxes, order.Taxes, fmt.Sprintf("Total should be 106, was %v", order.Total))
162+
assert.Equal(t, taxes, order.Taxes, fmt.Sprintf("Total should be 106, was %v", order.Taxes))
163163
})
164164
}
165165

@@ -907,6 +907,7 @@ func validateOrder(t *testing.T, expected, actual *models.Order) {
907907
assert.Equal(expected.Taxes, actual.Taxes)
908908
assert.Equal(expected.Shipping, actual.Shipping)
909909
assert.Equal(expected.SubTotal, actual.SubTotal)
910+
assert.Equal(expected.NetTotal, actual.NetTotal)
910911
assert.Equal(expected.Total, actual.Total)
911912
assert.Equal(expected.PaymentState, actual.PaymentState)
912913
assert.Equal(expected.FulfillmentState, actual.FulfillmentState)

calculator/calculator.go

+67-50
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Price struct {
1414

1515
Subtotal uint64
1616
Discount uint64
17+
NetTotal uint64
1718
Taxes uint64
1819
Total int64
1920
}
@@ -24,6 +25,7 @@ type ItemPrice struct {
2425

2526
Subtotal uint64
2627
Discount uint64
28+
NetTotal uint64
2729
Taxes uint64
2830
Total int64
2931
}
@@ -175,7 +177,6 @@ func (t *Tax) AppliesTo(country, productType string) bool {
175177
// currency, country, coupons, and discounts.
176178
func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params PriceParameters, log logrus.FieldLogger) Price {
177179
price := Price{}
178-
includeTaxes := settings != nil && settings.PricesIncludeTaxes
179180

180181
priceLogger := log.WithField("action", "calculate_price")
181182
if am, ok := jwtClaims["app_metadata"]; ok {
@@ -193,67 +194,39 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
193194
})
194195

195196
itemPrice := ItemPrice{Quantity: item.GetQuantity()}
196-
itemPrice.Subtotal = item.PriceInLowestUnit()
197-
198-
taxAmounts := []taxAmount{}
199-
if item.FixedVAT() != 0 {
200-
taxAmounts = append(taxAmounts, taxAmount{price: itemPrice.Subtotal, percentage: item.FixedVAT()})
201-
} else if settings != nil && item.TaxableItems() != nil && len(item.TaxableItems()) > 0 {
202-
for _, item := range item.TaxableItems() {
203-
amount := taxAmount{price: item.PriceInLowestUnit()}
204-
for _, t := range settings.Taxes {
205-
if t.AppliesTo(params.Country, item.ProductType()) {
206-
amount.percentage = t.Percentage
207-
break
208-
}
209-
}
210-
taxAmounts = append(taxAmounts, amount)
211-
}
212-
} else if settings != nil {
213-
for _, t := range settings.Taxes {
214-
if t.AppliesTo(params.Country, item.ProductType()) {
215-
taxAmounts = append(taxAmounts, taxAmount{price: itemPrice.Subtotal, percentage: t.Percentage})
216-
break
217-
}
218-
}
219-
}
220197

221-
if len(taxAmounts) != 0 {
222-
if includeTaxes {
223-
itemPrice.Subtotal = 0
224-
}
225-
for _, tax := range taxAmounts {
226-
if includeTaxes {
227-
tax.price = rint(float64(tax.price) / (100 + float64(tax.percentage)) * 100)
228-
itemPrice.Subtotal += tax.price
229-
}
230-
itemPrice.Taxes += rint(float64(tax.price) * float64(tax.percentage) / 100)
231-
}
232-
}
198+
singlePrice := item.PriceInLowestUnit()
199+
_, itemPrice.Subtotal = calculateTaxes(singlePrice, item, params, settings)
233200

201+
// apply discount to original price
234202
coupon := params.Coupon
235203
if coupon != nil && coupon.ValidForType(item.ProductType()) && coupon.ValidForProduct(item.ProductSku()) {
236-
itemPrice.Discount = calculateDiscount(itemPrice.Subtotal, itemPrice.Taxes, coupon.PercentageDiscount(), coupon.FixedDiscount(params.Currency), includeTaxes)
204+
itemPrice.Discount = calculateDiscount(singlePrice, coupon.PercentageDiscount(), coupon.FixedDiscount(params.Currency))
237205
}
238206
if settings != nil && settings.MemberDiscounts != nil {
239207
for _, discount := range settings.MemberDiscounts {
240208

241209
if jwtClaims != nil && claims.HasClaims(jwtClaims, discount.Claims) && discount.ValidForType(item.ProductType()) && discount.ValidForProduct(item.ProductSku()) {
242210
lineLogger = lineLogger.WithField("discount", discount.Claims)
243-
itemPrice.Discount += calculateDiscount(itemPrice.Subtotal, itemPrice.Taxes, discount.Percentage, discount.FixedDiscount(params.Currency), includeTaxes)
211+
itemPrice.Discount += calculateDiscount(singlePrice, discount.Percentage, discount.FixedDiscount(params.Currency))
244212
}
245213
}
246214
}
247215

248-
itemPrice.Total = int64(itemPrice.Subtotal+itemPrice.Taxes) - int64(itemPrice.Discount)
249-
if itemPrice.Total < 0 {
250-
itemPrice.Total = 0
216+
discountedPrice := uint64(0)
217+
if itemPrice.Discount < singlePrice {
218+
discountedPrice = singlePrice - itemPrice.Discount
251219
}
252220

221+
itemPrice.Taxes, itemPrice.NetTotal = calculateTaxes(discountedPrice, item, params, settings)
222+
223+
itemPrice.Total = int64(itemPrice.NetTotal + itemPrice.Taxes)
224+
253225
lineLogger.WithFields(
254226
logrus.Fields{
255227
"item_price": itemPrice.Total,
256228
"item_discount": itemPrice.Discount,
229+
"item_nettotal": itemPrice.NetTotal,
257230
"item_quantity": itemPrice.Quantity,
258231
"item_taxes": itemPrice.Taxes,
259232
}).Info("calculated item price")
@@ -262,28 +235,24 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
262235

263236
price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity)
264237
price.Discount += (itemPrice.Discount * itemPrice.Quantity)
238+
price.NetTotal += (itemPrice.NetTotal * itemPrice.Quantity)
265239
price.Taxes += (itemPrice.Taxes * itemPrice.Quantity)
266240
price.Total += (itemPrice.Total * int64(itemPrice.Quantity))
267241
}
268242

269-
price.Total = int64(price.Subtotal+price.Taxes) - int64(price.Discount)
270-
if price.Total < 0 {
271-
price.Total = 0
272-
}
243+
price.Total = int64(price.NetTotal + price.Taxes)
273244
priceLogger.WithFields(
274245
logrus.Fields{
275246
"total_price": price.Total,
276247
"total_discount": price.Discount,
248+
"total_net": price.NetTotal,
277249
"total_taxes": price.Taxes,
278250
}).Info("calculated total price")
279251

280252
return price
281253
}
282254

283-
func calculateDiscount(amountToDiscount, taxes, percentage, fixed uint64, includeTaxes bool) uint64 {
284-
if includeTaxes {
285-
amountToDiscount += taxes
286-
}
255+
func calculateDiscount(amountToDiscount, percentage, fixed uint64) uint64 {
287256
var discount uint64
288257
if percentage > 0 {
289258
discount = rint(float64(amountToDiscount) * float64(percentage) / 100)
@@ -296,6 +265,54 @@ func calculateDiscount(amountToDiscount, taxes, percentage, fixed uint64, includ
296265
return discount
297266
}
298267

268+
func calculateTaxes(amountToTax uint64, item Item, params PriceParameters, settings *Settings) (taxes uint64, subtotal uint64) {
269+
includeTaxes := settings != nil && settings.PricesIncludeTaxes
270+
originalPrice := item.PriceInLowestUnit()
271+
272+
taxAmounts := []taxAmount{}
273+
if item.FixedVAT() != 0 {
274+
taxAmounts = append(taxAmounts, taxAmount{price: amountToTax, percentage: item.FixedVAT()})
275+
} else if settings != nil && item.TaxableItems() != nil && len(item.TaxableItems()) > 0 {
276+
for _, item := range item.TaxableItems() {
277+
// because a discount may have been applied we need to determine the real price of this sub-item
278+
priceShare := float64(item.PriceInLowestUnit()) / float64(originalPrice)
279+
itemPrice := rint(float64(amountToTax) * priceShare)
280+
amount := taxAmount{price: itemPrice}
281+
for _, t := range settings.Taxes {
282+
if t.AppliesTo(params.Country, item.ProductType()) {
283+
amount.percentage = t.Percentage
284+
break
285+
}
286+
}
287+
taxAmounts = append(taxAmounts, amount)
288+
}
289+
} else if settings != nil {
290+
for _, t := range settings.Taxes {
291+
if t.AppliesTo(params.Country, item.ProductType()) {
292+
taxAmounts = append(taxAmounts, taxAmount{price: amountToTax, percentage: t.Percentage})
293+
break
294+
}
295+
}
296+
}
297+
298+
taxes = 0
299+
if len(taxAmounts) == 0 {
300+
subtotal = amountToTax
301+
return
302+
}
303+
304+
subtotal = 0
305+
for _, tax := range taxAmounts {
306+
if includeTaxes {
307+
tax.price = rint(float64(tax.price) / (100 + float64(tax.percentage)) * 100)
308+
}
309+
subtotal += tax.price
310+
taxes += rint(float64(tax.price) * float64(tax.percentage) / 100)
311+
}
312+
313+
return
314+
}
315+
299316
// Nopes - no `round` method in go
300317
// See https://github.com/golang/go/blob/master/src/math/floor.go#L58
301318

0 commit comments

Comments
 (0)