Skip to content

Commit 5e277bf

Browse files
authored
Merge pull request #156 from mraerino/fix/calculation-rounding-amounts
Fix rounding inaccuracies with quantity
2 parents 773034e + 61a8e7b commit 5e277bf

File tree

2 files changed

+44
-78
lines changed

2 files changed

+44
-78
lines changed

calculator/calculator.go

+41-71
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,38 @@ func (t *Tax) AppliesTo(country, productType string) bool {
173173
return applies
174174
}
175175

176+
func calculateAmountsForSingleItem(settings *Settings, lineLogger logrus.FieldLogger, jwtClaims map[string]interface{}, params PriceParameters, item Item, multiplier uint64) ItemPrice {
177+
itemPrice := ItemPrice{Quantity: item.GetQuantity()}
178+
179+
singlePrice := item.PriceInLowestUnit() * multiplier
180+
_, itemPrice.Subtotal = calculateTaxes(singlePrice, item, params, settings)
181+
182+
// apply discount to original price
183+
coupon := params.Coupon
184+
if coupon != nil && coupon.ValidForType(item.ProductType()) && coupon.ValidForProduct(item.ProductSku()) {
185+
itemPrice.Discount = calculateDiscount(singlePrice, coupon.PercentageDiscount(), coupon.FixedDiscount(params.Currency)*multiplier)
186+
}
187+
if settings != nil && settings.MemberDiscounts != nil {
188+
for _, discount := range settings.MemberDiscounts {
189+
190+
if jwtClaims != nil && claims.HasClaims(jwtClaims, discount.Claims) && discount.ValidForType(item.ProductType()) && discount.ValidForProduct(item.ProductSku()) {
191+
lineLogger = lineLogger.WithField("discount", discount.Claims)
192+
itemPrice.Discount += calculateDiscount(singlePrice, discount.Percentage, discount.FixedDiscount(params.Currency)*multiplier)
193+
}
194+
}
195+
}
196+
197+
discountedPrice := uint64(0)
198+
if itemPrice.Discount < singlePrice {
199+
discountedPrice = singlePrice - itemPrice.Discount
200+
}
201+
202+
itemPrice.Taxes, itemPrice.NetTotal = calculateTaxes(discountedPrice, item, params, settings)
203+
itemPrice.Total = int64(itemPrice.NetTotal + itemPrice.Taxes)
204+
205+
return itemPrice
206+
}
207+
176208
// CalculatePrice will calculate the final total price. It takes into account
177209
// currency, country, coupons, and discounts.
178210
func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params PriceParameters, log logrus.FieldLogger) Price {
@@ -193,34 +225,7 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
193225
"product_sku": item.ProductSku(),
194226
})
195227

196-
itemPrice := ItemPrice{Quantity: item.GetQuantity()}
197-
198-
singlePrice := item.PriceInLowestUnit()
199-
_, itemPrice.Subtotal = calculateTaxes(singlePrice, item, params, settings)
200-
201-
// apply discount to original price
202-
coupon := params.Coupon
203-
if coupon != nil && coupon.ValidForType(item.ProductType()) && coupon.ValidForProduct(item.ProductSku()) {
204-
itemPrice.Discount = calculateDiscount(singlePrice, coupon.PercentageDiscount(), coupon.FixedDiscount(params.Currency))
205-
}
206-
if settings != nil && settings.MemberDiscounts != nil {
207-
for _, discount := range settings.MemberDiscounts {
208-
209-
if jwtClaims != nil && claims.HasClaims(jwtClaims, discount.Claims) && discount.ValidForType(item.ProductType()) && discount.ValidForProduct(item.ProductSku()) {
210-
lineLogger = lineLogger.WithField("discount", discount.Claims)
211-
itemPrice.Discount += calculateDiscount(singlePrice, discount.Percentage, discount.FixedDiscount(params.Currency))
212-
}
213-
}
214-
}
215-
216-
discountedPrice := uint64(0)
217-
if itemPrice.Discount < singlePrice {
218-
discountedPrice = singlePrice - itemPrice.Discount
219-
}
220-
221-
itemPrice.Taxes, itemPrice.NetTotal = calculateTaxes(discountedPrice, item, params, settings)
222-
223-
itemPrice.Total = int64(itemPrice.NetTotal + itemPrice.Taxes)
228+
itemPrice := calculateAmountsForSingleItem(settings, lineLogger, jwtClaims, params, item, 1)
224229

225230
lineLogger.WithFields(
226231
logrus.Fields{
@@ -233,11 +238,13 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
233238

234239
price.Items = append(price.Items, itemPrice)
235240

236-
price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity)
237-
price.Discount += (itemPrice.Discount * itemPrice.Quantity)
238-
price.NetTotal += (itemPrice.NetTotal * itemPrice.Quantity)
239-
price.Taxes += (itemPrice.Taxes * itemPrice.Quantity)
240-
price.Total += (itemPrice.Total * int64(itemPrice.Quantity))
241+
// avoid issues with rounding when multiplying by quantity before taxation
242+
itemPriceMultiple := calculateAmountsForSingleItem(settings, lineLogger, jwtClaims, params, item, item.GetQuantity())
243+
price.Subtotal += itemPriceMultiple.Subtotal
244+
price.Discount += itemPriceMultiple.Discount
245+
price.NetTotal += itemPriceMultiple.NetTotal
246+
price.Taxes += itemPriceMultiple.Taxes
247+
price.Total += itemPriceMultiple.Total
241248
}
242249

243250
price.Total = int64(price.NetTotal + price.Taxes)
@@ -325,43 +332,6 @@ const (
325332
fracMask = 1<<shift - 1
326333
)
327334

328-
// Round returns the nearest integer, rounding half away from zero.
329-
//
330-
// Special cases are:
331-
// Round(±0) = ±0
332-
// Round(±Inf) = ±Inf
333-
// Round(NaN) = NaN
334-
func Round(x float64) float64 {
335-
// Round is a faster implementation of:
336-
//
337-
// func Round(x float64) float64 {
338-
// t := Trunc(x)
339-
// if Abs(x-t) >= 0.5 {
340-
// return t + Copysign(1, x)
341-
// }
342-
// return t
343-
// }
344-
bits := math.Float64bits(x)
345-
e := uint(bits>>shift) & mask
346-
if e < bias {
347-
// Round abs(x) < 1 including denormals.
348-
bits &= signMask // +-0
349-
if e == bias-1 {
350-
bits |= uvone // +-1
351-
}
352-
} else if e < bias+shift {
353-
// Round any abs(x) >= 1 containing a fractional component [0,1).
354-
//
355-
// Numbers with larger exponents are returned unchanged since they
356-
// must be either an integer, infinity, or NaN.
357-
const half = 1 << (shift - 1)
358-
e -= bias
359-
bits += half >> e
360-
bits &^= fracMask >> e
361-
}
362-
return math.Float64frombits(bits)
363-
}
364-
365335
func rint(x float64) uint64 {
366-
return uint64(Round(x))
336+
return uint64(math.Round(x))
367337
}

calculator/calculator_test.go

+3-7
Original file line numberDiff line numberDiff line change
@@ -208,15 +208,11 @@ func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) {
208208
params := PriceParameters{"USA", "USD", coupon, []Item{&TestItem{quantity: 2, price: 100, itemType: "test", vat: 9}}}
209209
price := CalculatePrice(settings, nil, params, testLogger)
210210

211-
// todo: This result is wrong because a rounding inaccuracy is quantified
212-
// Therefore the tax amount is not 9% of the net total
213-
// Correct net total: 165
214-
// Correct tax amount: 15
215211
validatePrice(t, price, Price{
216-
Subtotal: 184,
212+
Subtotal: 183,
217213
Discount: 20,
218-
NetTotal: 166,
219-
Taxes: 14,
214+
NetTotal: 165,
215+
Taxes: 15,
220216
Total: 180,
221217
})
222218
}

0 commit comments

Comments
 (0)