diff --git a/pkg/accounting/accounting.go b/pkg/accounting/accounting.go index 71251039d40..4b657a9b4fb 100644 --- a/pkg/accounting/accounting.go +++ b/pkg/accounting/accounting.go @@ -40,8 +40,8 @@ type Interface interface { Release(peer swarm.Address, price uint64) // Credit increases the balance the peer has with us (we "pay" the peer). Credit(peer swarm.Address, price uint64) error - // Debit increases the balance we have with the peer (we get "paid" back). - Debit(peer swarm.Address, price uint64) error + // PrepareDebit returns an accounting Action for the later debit to be executed on and to implement shadowing reserve on the other side + PrepareDebit(peer swarm.Address, price uint64) Action // Balance returns the current balance for the given peer. Balance(peer swarm.Address) (*big.Int, error) // SurplusBalance returns the current surplus balance for the given peer. @@ -54,14 +54,37 @@ type Interface interface { CompensatedBalances() (map[string]*big.Int, error) } +// Action represents an accounting action that can be applied +type Action interface { + // Cleanup cleans up an action. Must be called wether it was applied or not. + Cleanup() + // Apply applies an action + Apply() error +} + +// debitAction represents a future debit +type debitAction struct { + accounting *Accounting + price *big.Int + peer swarm.Address + accountingPeer *accountingPeer + applied bool +} + +// PayFunc is the function used for async monetary settlement type PayFunc func(context.Context, swarm.Address, *big.Int) +// RefreshFunc is the function used for sync time-based settlement +type RefreshFunc func(context.Context, swarm.Address, *big.Int, *big.Int) (*big.Int, int64, error) + // accountingPeer holds all in-memory accounting information for one peer. type accountingPeer struct { - lock sync.Mutex // lock to be held during any accounting action for this peer - reservedBalance *big.Int // amount currently reserved for active peer interaction - paymentThreshold *big.Int // the threshold at which the peer expects us to pay - paymentOngoing bool // indicate if we are currently settling with the peer + lock sync.Mutex // lock to be held during any accounting action for this peer + reservedBalance *big.Int // amount currently reserved for active peer interaction + shadowReservedBalance *big.Int // amount currently reserved for active peer interaction + paymentThreshold *big.Int // the threshold at which the peer expects us to pay + refreshTimestamp int64 // last time we attempted time-based settlement + paymentOngoing bool // indicate if we are currently settling with the peer } // Accounting is the main implementation of the accounting interface. @@ -78,6 +101,8 @@ type Accounting struct { paymentTolerance *big.Int earlyPayment *big.Int payFunction PayFunc + refreshFunction RefreshFunc + refreshRate *big.Int pricing pricing.Interface metrics metrics } @@ -103,6 +128,7 @@ func NewAccounting( Logger logging.Logger, Store storage.StateStorer, Pricing pricing.Interface, + refreshRate *big.Int, ) (*Accounting, error) { return &Accounting{ accountingPeers: make(map[string]*accountingPeer), @@ -113,6 +139,7 @@ func NewAccounting( store: Store, pricing: Pricing, metrics: newMetrics(), + refreshRate: refreshRate, }, nil } @@ -129,17 +156,15 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint return fmt.Errorf("failed to load balance: %w", err) } } + currentDebt := new(big.Int).Neg(currentBalance) + if currentDebt.Cmp(big.NewInt(0)) < 0 { + currentDebt.SetInt64(0) + } bigPrice := new(big.Int).SetUint64(price) nextReserved := new(big.Int).Add(accountingPeer.reservedBalance, bigPrice) - expectedBalance := new(big.Int).Sub(currentBalance, nextReserved) - - // Determine if we will owe anything to the peer, if we owe less than 0, we conclude we owe nothing - expectedDebt := new(big.Int).Neg(expectedBalance) - if expectedDebt.Cmp(big.NewInt(0)) < 0 { - expectedDebt.SetInt64(0) - } + expectedDebt := new(big.Int).Add(currentDebt, nextReserved) threshold := new(big.Int).Set(accountingPeer.paymentThreshold) if threshold.Cmp(a.earlyPayment) > 0 { @@ -164,7 +189,7 @@ func (a *Accounting) Reserve(ctx context.Context, peer swarm.Address, price uint // and we are actually in debt, trigger settlement. // we pay early to avoid needlessly blocking request later when concurrent requests occur and we are already close to the payment threshold. if increasedExpectedDebt.Cmp(threshold) >= 0 && currentBalance.Cmp(big.NewInt(0)) < 0 { - err = a.settle(context.Background(), peer, accountingPeer) + err = a.settle(peer, accountingPeer) if err != nil { return fmt.Errorf("failed to settle with peer %v: %v", peer, err) } @@ -231,10 +256,9 @@ func (a *Accounting) Credit(peer swarm.Address, price uint64) error { // Settle all debt with a peer. The lock on the accountingPeer must be held when // called. -func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *accountingPeer) error { - if balance.paymentOngoing { - return nil - } +func (a *Accounting) settle(peer swarm.Address, balance *accountingPeer) error { + now := time.Now().Unix() + timeElapsed := now - balance.refreshTimestamp oldBalance, err := a.Balance(peer) if err != nil { @@ -243,101 +267,55 @@ func (a *Accounting) settle(ctx context.Context, peer swarm.Address, balance *ac } } - // Don't do anything if there is no actual debt. - // This might be the case if the peer owes us and the total reserve for a - // peer exceeds the payment treshold. - if oldBalance.Cmp(big.NewInt(0)) >= 0 { - return nil - } - - // This is safe because of the earlier check for oldbalance < 0 and the check for != MinInt64 - paymentAmount := new(big.Int).Neg(oldBalance) - - balance.paymentOngoing = true - - go a.payFunction(ctx, peer, paymentAmount) - - return nil -} - -// Debit increases the amount of debt we have with the given peer (and decreases -// existing credit). -func (a *Accounting) Debit(peer swarm.Address, price uint64) error { - accountingPeer := a.getAccountingPeer(peer) - - accountingPeer.lock.Lock() - defer accountingPeer.lock.Unlock() - - cost := new(big.Int).SetUint64(price) - // see if peer has surplus balance to deduct this transaction of - surplusBalance, err := a.SurplusBalance(peer) if err != nil { - return fmt.Errorf("failed to get surplus balance: %w", err) + return fmt.Errorf("failed to load balance: %w", err) } - if surplusBalance.Cmp(big.NewInt(0)) > 0 { - // get new surplus balance after deduct - newSurplusBalance := new(big.Int).Sub(surplusBalance, cost) - - // if nothing left for debiting, store new surplus balance and return from debit - if newSurplusBalance.Cmp(big.NewInt(0)) >= 0 { - a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is %d", peer, price, newSurplusBalance) + if surplusBalance.Cmp(big.NewInt(0)) < 0 { + return ErrInvalidValue + } - err = a.store.Put(peerSurplusBalanceKey(peer), newSurplusBalance) - if err != nil { - return fmt.Errorf("failed to persist surplus balance: %w", err) - } - // count debit operations, terminate early - a.metrics.TotalDebitedAmount.Add(float64(price)) - a.metrics.DebitEventsCount.Inc() - return nil - } + compensatedBalance := new(big.Int).Sub(oldBalance, surplusBalance) - // if surplus balance didn't cover full transaction, let's continue with leftover part as cost - debitIncrease := new(big.Int).Sub(new(big.Int).SetUint64(price), surplusBalance) + paymentAmount := new(big.Int).Neg(compensatedBalance) - // conversion to uint64 is safe because we know the relationship between the values by now, but let's make a sanity check - if debitIncrease.Cmp(big.NewInt(0)) <= 0 { - return fmt.Errorf("sanity check failed for partial debit after surplus balance drawn") + // Don't do anything if there is no actual debt. + // This might be the case if the peer owes us and the total reserve for a + // peer exceeds the payment treshold. + if !balance.paymentOngoing && paymentAmount.Cmp(big.NewInt(0)) > 0 && timeElapsed > 0 { + shadowBalance, err := a.shadowBalance(peer) + if err != nil { + return err } - cost.Set(debitIncrease) - - // if we still have something to debit, than have run out of surplus balance, - // let's store 0 as surplus balance - a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is 0", peer, debitIncrease) - err = a.store.Put(peerSurplusBalanceKey(peer), big.NewInt(0)) + acceptedAmount, timestamp, err := a.refreshFunction(context.Background(), peer, paymentAmount, shadowBalance) if err != nil { - return fmt.Errorf("failed to persist surplus balance: %w", err) + return fmt.Errorf("refresh failure: %w", err) } - } + balance.refreshTimestamp = timestamp - currentBalance, err := a.Balance(peer) - if err != nil { - if !errors.Is(err, ErrPeerNoBalance) { - return fmt.Errorf("failed to load balance: %w", err) - } - } + // Get nextBalance by safely increasing current balance with price + oldBalance = new(big.Int).Add(oldBalance, acceptedAmount) - // Get nextBalance by safely increasing current balance with price - nextBalance := new(big.Int).Add(currentBalance, cost) + a.logger.Tracef("registering refreshment sent to peer %v with amount %d, new balance is %d", peer, acceptedAmount, oldBalance) - a.logger.Tracef("debiting peer %v with price %d, new balance is %d", peer, price, nextBalance) - - err = a.store.Put(peerBalanceKey(peer), nextBalance) - if err != nil { - return fmt.Errorf("failed to persist balance: %w", err) + err = a.store.Put(peerBalanceKey(peer), oldBalance) + if err != nil { + return fmt.Errorf("settle: failed to persist balance: %w", err) + } } - a.metrics.TotalDebitedAmount.Add(float64(price)) - a.metrics.DebitEventsCount.Inc() + // if there is no monetary settlement happening, check if there is something to settle + if !balance.paymentOngoing { + paymentAmount := new(big.Int).Neg(oldBalance) - if nextBalance.Cmp(new(big.Int).Add(a.paymentThreshold, a.paymentTolerance)) >= 0 { - // peer too much in debt - a.metrics.AccountingDisconnectsCount.Inc() - return p2p.NewBlockPeerError(10000*time.Hour, ErrDisconnectThresholdExceeded) + if paymentAmount.Cmp(a.refreshRate) >= 0 { + balance.paymentOngoing = true + balance.shadowReservedBalance.Add(balance.shadowReservedBalance, paymentAmount) + go a.payFunction(context.Background(), peer, paymentAmount) + } } return nil @@ -419,7 +397,8 @@ func (a *Accounting) getAccountingPeer(peer swarm.Address) *accountingPeer { peerData, ok := a.accountingPeers[peer.String()] if !ok { peerData = &accountingPeer{ - reservedBalance: big.NewInt(0), + reservedBalance: big.NewInt(0), + shadowReservedBalance: big.NewInt(0), // initially assume the peer has the same threshold as us paymentThreshold: new(big.Int).Set(a.paymentThreshold), } @@ -541,6 +520,113 @@ func surplusBalanceKeyPeer(key []byte) (swarm.Address, error) { return addr, nil } +// NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold +func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { + accountingPeer := a.getAccountingPeer(peer) + + accountingPeer.lock.Lock() + defer accountingPeer.lock.Unlock() + + accountingPeer.paymentThreshold.Set(paymentThreshold) + return nil +} + +func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) { + + accountingPeer := a.getAccountingPeer(peer) + accountingPeer.lock.Lock() + defer accountingPeer.lock.Unlock() + + balance := new(big.Int) + zero := big.NewInt(0) + + err := a.store.Get(peerBalanceKey(peer), &balance) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return nil, err + } + balance = big.NewInt(0) + } + + peerDebt := new(big.Int).Add(balance, accountingPeer.shadowReservedBalance) + + if peerDebt.Cmp(zero) < 0 { + return zero, nil + } + + return peerDebt, nil +} + +// ShadowBalance returns the current debt reduced by any potentially debitable amount stored in shadowReservedBalance +func (a *Accounting) shadowBalance(peer swarm.Address) (shadowBalance *big.Int, err error) { + accountingPeer := a.getAccountingPeer(peer) + balance := new(big.Int) + zero := big.NewInt(0) + + err = a.store.Get(peerBalanceKey(peer), &balance) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return zero, nil + } + return nil, err + } + + if balance.Cmp(zero) >= 0 { + return zero, nil + } + + negativeBalance := new(big.Int).Neg(balance) + + surplusBalance, err := a.SurplusBalance(peer) + if err != nil { + return nil, err + } + + debt := new(big.Int).Add(negativeBalance, surplusBalance) + + if debt.Cmp(accountingPeer.shadowReservedBalance) < 0 { + return zero, nil + } + + shadowBalance = new(big.Int).Sub(negativeBalance, accountingPeer.shadowReservedBalance) + + return shadowBalance, nil +} + +func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) { + accountingPeer := a.getAccountingPeer(peer) + + accountingPeer.lock.Lock() + defer accountingPeer.lock.Unlock() + + accountingPeer.paymentOngoing = false + accountingPeer.shadowReservedBalance.Sub(accountingPeer.shadowReservedBalance, amount) + + if receivedError != nil { + a.logger.Warningf("accouting: payment failure %v", receivedError) + return + } + + currentBalance, err := a.Balance(peer) + if err != nil { + if !errors.Is(err, ErrPeerNoBalance) { + a.logger.Warningf("accounting: notifypaymentsent failed to load balance: %v", err) + return + } + } + + // Get nextBalance by safely increasing current balance with price + nextBalance := new(big.Int).Add(currentBalance, amount) + + a.logger.Tracef("registering payment sent to peer %v with amount %d, new balance is %d", peer, amount, nextBalance) + + err = a.store.Put(peerBalanceKey(peer), nextBalance) + if err != nil { + a.logger.Warningf("accounting: notifypaymentsent failed to persist balance: %v", err) + return + } +} + // NotifyPayment is called by Settlement when we receive a payment. func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error { accountingPeer := a.getAccountingPeer(peer) @@ -603,7 +689,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) } increasedSurplus := new(big.Int).Add(surplus, surplusGrowth) - a.logger.Tracef("surplus crediting peer %v with amount %d due to payment, new surplus balance is %d", peer, surplusGrowth, increasedSurplus) + a.logger.Tracef("surplus crediting peer %v with amount %d due to refreshment, new surplus balance is %d", peer, surplusGrowth, increasedSurplus) err = a.store.Put(peerSurplusBalanceKey(peer), increasedSurplus) if err != nil { @@ -614,65 +700,157 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) return nil } -// NotifyPaymentThreshold should be called to notify accounting of changes in the payment threshold -func (a *Accounting) NotifyPaymentThreshold(peer swarm.Address, paymentThreshold *big.Int) error { +// NotifyPayment is called by Settlement when we receive a payment. +func (a *Accounting) NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error { accountingPeer := a.getAccountingPeer(peer) accountingPeer.lock.Lock() defer accountingPeer.lock.Unlock() - accountingPeer.paymentThreshold.Set(paymentThreshold) - return nil -} - -func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) { - zero := big.NewInt(0) - balance, err := a.Balance(peer) + currentBalance, err := a.Balance(peer) if err != nil { - if errors.Is(err, ErrPeerNoBalance) { - return zero, nil + if !errors.Is(err, ErrPeerNoBalance) { + return err } - return nil, err } - if balance.Cmp(zero) <= 0 { - return zero, nil + // if current balance is positive, let's make a partial credit to + nextBalance := new(big.Int).Sub(currentBalance, amount) + + // Don't allow a payment to put us into debt + // This is to prevent another node tricking us into settling by settling + // first (e.g. send a bouncing cheque to trigger an honest cheque in swap). + + a.logger.Tracef("crediting peer %v with amount %d due to payment, new balance is %d", peer, amount, nextBalance) + + err = a.store.Put(peerBalanceKey(peer), nextBalance) + if err != nil { + return fmt.Errorf("failed to persist balance: %w", err) } - return balance, nil + return nil } -func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) { +func (a *Accounting) PrepareDebit(peer swarm.Address, price uint64) Action { accountingPeer := a.getAccountingPeer(peer) accountingPeer.lock.Lock() defer accountingPeer.lock.Unlock() - accountingPeer.paymentOngoing = false + bigPrice := new(big.Int).SetUint64(price) - if receivedError != nil { - a.logger.Warningf("accouting: payment failure %v", receivedError) - return + accountingPeer.shadowReservedBalance = new(big.Int).Add(accountingPeer.shadowReservedBalance, bigPrice) + + return &debitAction{ + accounting: a, + price: bigPrice, + peer: peer, + accountingPeer: accountingPeer, + applied: false, } - currentBalance, err := a.Balance(peer) +} + +func (d *debitAction) Apply() error { + d.accountingPeer.lock.Lock() + defer d.accountingPeer.lock.Unlock() + + a := d.accounting + + cost := new(big.Int).Set(d.price) + // see if peer has surplus balance to deduct this transaction of + + surplusBalance, err := a.SurplusBalance(d.peer) + if err != nil { + return fmt.Errorf("failed to get surplus balance: %w", err) + } + if surplusBalance.Cmp(big.NewInt(0)) > 0 { + + // get new surplus balance after deduct + newSurplusBalance := new(big.Int).Sub(surplusBalance, cost) + + // if nothing left for debiting, store new surplus balance and return from debit + if newSurplusBalance.Cmp(big.NewInt(0)) >= 0 { + a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is %d", d.peer, d.price, newSurplusBalance) + + err = a.store.Put(peerSurplusBalanceKey(d.peer), newSurplusBalance) + if err != nil { + return fmt.Errorf("failed to persist surplus balance: %w", err) + } + // count debit operations, terminate early + tot, _ := big.NewFloat(0).SetInt(d.price).Float64() + + d.applied = true + d.accountingPeer.shadowReservedBalance = new(big.Int).Sub(d.accountingPeer.shadowReservedBalance, d.price) + a.metrics.TotalDebitedAmount.Add(tot) + a.metrics.DebitEventsCount.Inc() + return nil + } + + // if surplus balance didn't cover full transaction, let's continue with leftover part as cost + debitIncrease := new(big.Int).Sub(d.price, surplusBalance) + + // conversion to uint64 is safe because we know the relationship between the values by now, but let's make a sanity check + if debitIncrease.Cmp(big.NewInt(0)) <= 0 { + return fmt.Errorf("sanity check failed for partial debit after surplus balance drawn") + } + cost.Set(debitIncrease) + + // if we still have something to debit, than have run out of surplus balance, + // let's store 0 as surplus balance + a.logger.Tracef("surplus debiting peer %v with value %d, new surplus balance is 0", d.peer, debitIncrease) + + err = a.store.Put(peerSurplusBalanceKey(d.peer), big.NewInt(0)) + if err != nil { + return fmt.Errorf("failed to persist surplus balance: %w", err) + } + + } + + currentBalance, err := a.Balance(d.peer) if err != nil { if !errors.Is(err, ErrPeerNoBalance) { - a.logger.Warningf("accounting: notifypaymentsent failed to load balance: %v", err) - return + return fmt.Errorf("failed to load balance: %w", err) } } // Get nextBalance by safely increasing current balance with price - nextBalance := new(big.Int).Add(currentBalance, amount) + nextBalance := new(big.Int).Add(currentBalance, cost) - a.logger.Tracef("registering payment sent to peer %v with amount %d, new balance is %d", peer, amount, nextBalance) + a.logger.Tracef("debiting peer %v with price %d, new balance is %d", d.peer, d.price, nextBalance) - err = a.store.Put(peerBalanceKey(peer), nextBalance) + err = a.store.Put(peerBalanceKey(d.peer), nextBalance) if err != nil { - a.logger.Warningf("accounting: notifypaymentsent failed to persist balance: %v", err) - return + return fmt.Errorf("failed to persist balance: %w", err) } + + d.applied = true + d.accountingPeer.shadowReservedBalance = new(big.Int).Sub(d.accountingPeer.shadowReservedBalance, d.price) + + tot, _ := big.NewFloat(0).SetInt(d.price).Float64() + + a.metrics.TotalDebitedAmount.Add(tot) + a.metrics.DebitEventsCount.Inc() + + if nextBalance.Cmp(new(big.Int).Add(a.paymentThreshold, a.paymentTolerance)) >= 0 { + // peer too much in debt + a.metrics.AccountingDisconnectsCount.Inc() + return p2p.NewBlockPeerError(10000*time.Hour, ErrDisconnectThresholdExceeded) + } + + return nil +} + +func (d *debitAction) Cleanup() { + if !d.applied { + d.accountingPeer.lock.Lock() + defer d.accountingPeer.lock.Unlock() + d.accountingPeer.shadowReservedBalance = new(big.Int).Sub(d.accountingPeer.shadowReservedBalance, d.price) + } +} + +func (a *Accounting) SetRefreshFunc(f RefreshFunc) { + a.refreshFunction = f } func (a *Accounting) SetPayFunc(f PayFunc) { diff --git a/pkg/accounting/accounting_test.go b/pkg/accounting/accounting_test.go index 3defaef9716..f4e6243dcd5 100644 --- a/pkg/accounting/accounting_test.go +++ b/pkg/accounting/accounting_test.go @@ -83,10 +83,12 @@ func TestAccountingAddBalance(t *testing.T) { } acc.Release(booking.peer, uint64(-booking.price)) } else { - err = acc.Debit(booking.peer, uint64(booking.price)) + debitAction := acc.PrepareDebit(booking.peer, uint64(booking.price)) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() } balance, err := acc.Balance(booking.peer) @@ -125,10 +127,12 @@ func TestAccountingAdd_persistentBalances(t *testing.T) { } peer1DebitAmount := testPrice - err = acc.Debit(peer1Addr, peer1DebitAmount) + debitAction := acc.PrepareDebit(peer1Addr, peer1DebitAmount) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() peer2CreditAmount := 2 * testPrice err = acc.Credit(peer2Addr, peer2CreditAmount) @@ -206,16 +210,20 @@ func TestAccountingDisconnect(t *testing.T) { } // put the peer 1 unit away from disconnect - err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()+testPaymentTolerance.Uint64()-1) + debitAction := acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()+testPaymentTolerance.Uint64()-1) + err = debitAction.Apply() if err != nil { t.Fatal("expected no error while still within tolerance") } + debitAction.Cleanup() // put the peer over thee threshold - err = acc.Debit(peer1Addr, 1) + debitAction = acc.PrepareDebit(peer1Addr, 1) + err = debitAction.Apply() if err == nil { t.Fatal("expected Add to return error") } + debitAction.Cleanup() var e *p2p.BlockPeerError if !errors.As(err, &e) { @@ -415,10 +423,12 @@ func TestAccountingSurplusBalance(t *testing.T) { t.Fatal(err) } // Try Debiting a large amount to peer so balance is large positive - err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()-1) + debitAction := acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()-1) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() // Notify of incoming payment from same peer, so balance goes to 0 with surplusbalance 2 err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).Add(testPaymentThreshold, big.NewInt(1))) if err != nil { @@ -462,10 +472,12 @@ func TestAccountingSurplusBalance(t *testing.T) { t.Fatal("Not expected balance, expected 0") } // Debit for same peer, so balance stays 0 with surplusbalance decreasing to 2 - err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()) + debitAction = acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()) + err = debitAction.Apply() if err != nil { t.Fatal("Unexpected error from Credit") } + debitAction.Cleanup() // samity check surplus balance val, err = acc.SurplusBalance(peer1Addr) if err != nil { @@ -483,10 +495,12 @@ func TestAccountingSurplusBalance(t *testing.T) { t.Fatal("Not expected balance, expected 0") } // Debit for same peer, so balance goes to 9998 (testpaymentthreshold - 2) with surplusbalance decreasing to 0 - err = acc.Debit(peer1Addr, testPaymentThreshold.Uint64()) + debitAction = acc.PrepareDebit(peer1Addr, testPaymentThreshold.Uint64()) + err = debitAction.Apply() if err != nil { t.Fatal("Unexpected error from Debit") } + debitAction.Cleanup() // samity check surplus balance val, err = acc.SurplusBalance(peer1Addr) if err != nil { @@ -523,20 +537,24 @@ func TestAccountingNotifyPaymentReceived(t *testing.T) { } debtAmount := uint64(100) - err = acc.Debit(peer1Addr, debtAmount+testPaymentTolerance.Uint64()) + debitAction := acc.PrepareDebit(peer1Addr, debtAmount+testPaymentTolerance.Uint64()) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64())) if err != nil { t.Fatal(err) } - err = acc.Debit(peer1Addr, debtAmount) + debitAction = acc.PrepareDebit(peer1Addr, debtAmount) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() err = acc.NotifyPaymentReceived(peer1Addr, new(big.Int).SetUint64(debtAmount+testPaymentTolerance.Uint64()+1)) if err != nil { @@ -677,10 +695,12 @@ func TestAccountingPeerDebt(t *testing.T) { peer1Addr := swarm.MustParseHexAddress("00112233") debt := uint64(1000) - err = acc.Debit(peer1Addr, debt) + debitAction := acc.PrepareDebit(peer1Addr, debt) + err = debitAction.Apply() if err != nil { t.Fatal(err) } + debitAction.Cleanup() actualDebt, err := acc.PeerDebt(peer1Addr) if err != nil { t.Fatal(err) diff --git a/pkg/accounting/mock/accounting.go b/pkg/accounting/mock/accounting.go index 8cb25cc7f26..6a93c420807 100644 --- a/pkg/accounting/mock/accounting.go +++ b/pkg/accounting/mock/accounting.go @@ -22,8 +22,9 @@ type Service struct { reserveFunc func(ctx context.Context, peer swarm.Address, price uint64) error releaseFunc func(peer swarm.Address, price uint64) creditFunc func(peer swarm.Address, price uint64) error - debitFunc func(peer swarm.Address, price uint64) error + prepareDebitFunc func(peer swarm.Address, price uint64) accounting.Action balanceFunc func(swarm.Address) (*big.Int, error) + shadowBalanceFunc func(swarm.Address) (*big.Int, error) balancesFunc func() (map[string]*big.Int, error) compensatedBalanceFunc func(swarm.Address) (*big.Int, error) compensatedBalancesFunc func() (map[string]*big.Int, error) @@ -31,6 +32,13 @@ type Service struct { balanceSurplusFunc func(swarm.Address) (*big.Int, error) } +type debitAction struct { + accounting *Service + price *big.Int + peer swarm.Address + applied bool +} + // WithReserveFunc sets the mock Reserve function func WithReserveFunc(f func(ctx context.Context, peer swarm.Address, price uint64) error) Option { return optionFunc(func(s *Service) { @@ -53,9 +61,9 @@ func WithCreditFunc(f func(peer swarm.Address, price uint64) error) Option { } // WithDebitFunc sets the mock Debit function -func WithDebitFunc(f func(peer swarm.Address, price uint64) error) Option { +func WithPrepareDebitFunc(f func(peer swarm.Address, price uint64) accounting.Action) Option { return optionFunc(func(s *Service) { - s.debitFunc = f + s.prepareDebitFunc = f }) } @@ -136,21 +144,36 @@ func (s *Service) Credit(peer swarm.Address, price uint64) error { } // Debit is the mock function wrapper that calls the set implementation -func (s *Service) Debit(peer swarm.Address, price uint64) error { - if s.debitFunc != nil { - return s.debitFunc(peer, price) +func (s *Service) PrepareDebit(peer swarm.Address, price uint64) accounting.Action { + if s.prepareDebitFunc != nil { + return s.prepareDebitFunc(peer, price) } - s.lock.Lock() - defer s.lock.Unlock() - if bal, ok := s.balances[peer.String()]; ok { - s.balances[peer.String()] = new(big.Int).Add(bal, new(big.Int).SetUint64(price)) + bigPrice := new(big.Int).SetUint64(price) + return &debitAction{ + accounting: s, + price: bigPrice, + peer: peer, + applied: false, + } + +} + +func (a *debitAction) Apply() error { + a.accounting.lock.Lock() + defer a.accounting.lock.Unlock() + + if bal, ok := a.accounting.balances[a.peer.String()]; ok { + a.accounting.balances[a.peer.String()] = new(big.Int).Add(bal, new(big.Int).Set(a.price)) } else { - s.balances[peer.String()] = new(big.Int).SetUint64(price) + a.accounting.balances[a.peer.String()] = new(big.Int).Set(a.price) } + return nil } +func (a *debitAction) Cleanup() {} + // Balance is the mock function wrapper that calls the set implementation func (s *Service) Balance(peer swarm.Address) (*big.Int, error) { if s.balanceFunc != nil { @@ -165,6 +188,19 @@ func (s *Service) Balance(peer swarm.Address) (*big.Int, error) { } } +func (s *Service) ShadowBalance(peer swarm.Address) (*big.Int, error) { + if s.shadowBalanceFunc != nil { + return s.shadowBalanceFunc(peer) + } + s.lock.Lock() + defer s.lock.Unlock() + if bal, ok := s.balances[peer.String()]; ok { + return new(big.Int).Neg(bal), nil + } else { + return big.NewInt(0), nil + } +} + // Balances is the mock function wrapper that calls the set implementation func (s *Service) Balances() (map[string]*big.Int, error) { if s.balancesFunc != nil { diff --git a/pkg/debugapi/debugapi.go b/pkg/debugapi/debugapi.go index 01b8291d774..76cb7e687c7 100644 --- a/pkg/debugapi/debugapi.go +++ b/pkg/debugapi/debugapi.go @@ -45,6 +45,7 @@ type Service struct { tags *tags.Tags accounting accounting.Interface settlement settlement.Interface + pseudo settlement.Interface chequebookEnabled bool chequebook chequebook.Service swap swap.ApiInterface @@ -97,6 +98,10 @@ func (s *Service) Configure(p2p p2p.DebugService, pingpong pingpong.Interface, t s.setRouter(s.newRouter()) } +func (s *Service) SetPseudo(pseudo settlement.Interface) { + s.pseudo = pseudo +} + // ServeHTTP implements http.Handler interface. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // protect handler as it is changed by the Configure method diff --git a/pkg/debugapi/router.go b/pkg/debugapi/router.go index 2b9d01075c6..0de6ccc9e64 100644 --- a/pkg/debugapi/router.go +++ b/pkg/debugapi/router.go @@ -129,6 +129,10 @@ func (s *Service) newRouter() *mux.Router { "GET": http.HandlerFunc(s.settlementsHandler), }) + router.Handle("/pseudo", jsonhttp.MethodHandler{ + "GET": http.HandlerFunc(s.settlementsHandlerPseudo), + }) + router.Handle("/settlements/{peer}", jsonhttp.MethodHandler{ "GET": http.HandlerFunc(s.peerSettlementsHandler), }) diff --git a/pkg/debugapi/settlements.go b/pkg/debugapi/settlements.go index 408a5ee3083..d8991f97bbd 100644 --- a/pkg/debugapi/settlements.go +++ b/pkg/debugapi/settlements.go @@ -143,3 +143,59 @@ func (s *Service) peerSettlementsHandler(w http.ResponseWriter, r *http.Request) SettlementSent: sent, }) } + +func (s *Service) settlementsHandlerPseudo(w http.ResponseWriter, r *http.Request) { + + settlementsSent, err := s.pseudo.SettlementsSent() + if err != nil { + jsonhttp.InternalServerError(w, errCantSettlements) + s.logger.Debugf("debug api: sent settlements: %v", err) + s.logger.Error("debug api: can not get sent settlements") + return + } + settlementsReceived, err := s.pseudo.SettlementsReceived() + if err != nil { + jsonhttp.InternalServerError(w, errCantSettlements) + s.logger.Debugf("debug api: received settlements: %v", err) + s.logger.Error("debug api: can not get received settlements") + return + } + + totalReceived := big.NewInt(0) + totalSent := big.NewInt(0) + + settlementResponses := make(map[string]settlementResponse) + + for a, b := range settlementsSent { + settlementResponses[a] = settlementResponse{ + Peer: a, + SettlementSent: b, + SettlementReceived: big.NewInt(0), + } + totalSent.Add(b, totalSent) + } + + for a, b := range settlementsReceived { + if _, ok := settlementResponses[a]; ok { + t := settlementResponses[a] + t.SettlementReceived = b + settlementResponses[a] = t + } else { + settlementResponses[a] = settlementResponse{ + Peer: a, + SettlementSent: big.NewInt(0), + SettlementReceived: b, + } + } + totalReceived.Add(b, totalReceived) + } + + settlementResponsesArray := make([]settlementResponse, len(settlementResponses)) + i := 0 + for k := range settlementResponses { + settlementResponsesArray[i] = settlementResponses[k] + i++ + } + + jsonhttp.OK(w, settlementsResponse{TotalSettlementReceived: totalReceived, TotalSettlementSent: totalSent, Settlements: settlementResponsesArray}) +} diff --git a/pkg/node/node.go b/pkg/node/node.go index c7dffcb1dc9..216f06fc349 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -54,7 +54,6 @@ import ( "github.com/ethersphere/bee/pkg/recovery" "github.com/ethersphere/bee/pkg/resolver/multiresolver" "github.com/ethersphere/bee/pkg/retrieval" - settlement "github.com/ethersphere/bee/pkg/settlement" "github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/swap" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook" @@ -134,6 +133,11 @@ type Options struct { BlockTime uint64 } +const ( + refreshRate = int64(1000000000000) + basePrice = 1000000000 +) + func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, signer crypto.Signer, networkID uint64, logger logging.Logger, libp2pPrivateKey, pssPrivateKey *ecdsa.PrivateKey, o Options) (b *Bee, err error) { tracer, tracerCloser, err := tracing.NewTracer(&tracing.Options{ Enabled: o.TracingEnabled, @@ -416,7 +420,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, } } - var settlement settlement.Interface var swapService *swap.Service kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, StandaloneMode: o.Standalone, BootnodeMode: o.BootnodeMode}) @@ -442,7 +445,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold) } - pricer := pricer.NewFixedPricer(swarmAddress, 1000000000) + pricer := pricer.NewFixedPricer(swarmAddress, basePrice) minThreshold := pricer.MostExpensive() @@ -469,6 +472,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, if !ok { return nil, fmt.Errorf("invalid payment early: %s", paymentEarly) } + acc, err := accounting.NewAccounting( paymentThreshold, paymentTolerance, @@ -476,11 +480,19 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, logger, stateStore, pricing, + big.NewInt(refreshRate), ) if err != nil { return nil, fmt.Errorf("accounting: %w", err) } + pseudosettleService := pseudosettle.New(p2ps, logger, stateStore, acc, big.NewInt(refreshRate), p2ps) + if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil { + return nil, fmt.Errorf("pseudosettle service: %w", err) + } + + acc.SetRefreshFunc(pseudosettleService.Pay) + if o.SwapEnable { swapService, err = InitSwap( p2ps, @@ -496,17 +508,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, if err != nil { return nil, err } - settlement = swapService - } else { - pseudosettleService := pseudosettle.New(p2ps, logger, stateStore, acc) - if err = p2ps.AddProtocol(pseudosettleService.Protocol()); err != nil { - return nil, fmt.Errorf("pseudosettle service: %w", err) - } - settlement = pseudosettleService + acc.SetPayFunc(swapService.Pay) } - acc.SetPayFunc(settlement.Pay) - pricing.SetPaymentThresholdObserver(acc) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer) @@ -641,12 +645,16 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, debugAPIService.MustRegisterMetrics(l.Metrics()...) } - if l, ok := settlement.(metrics.Collector); ok { - debugAPIService.MustRegisterMetrics(l.Metrics()...) + debugAPIService.MustRegisterMetrics(pseudosettleService.Metrics()...) + + if swapService != nil { + debugAPIService.MustRegisterMetrics(swapService.Metrics()...) } // inject dependencies and configure full debug api http path routes - debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService, batchStore) + debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, swapService, o.SwapEnable, swapService, chequebookService, batchStore) + + debugAPIService.SetPseudo(pseudosettleService) } if err := kad.Start(p2pCtx); err != nil { diff --git a/pkg/p2p/streamtest/streamtest.go b/pkg/p2p/streamtest/streamtest.go index f743b82412c..f20327b22f8 100644 --- a/pkg/p2p/streamtest/streamtest.go +++ b/pkg/p2p/streamtest/streamtest.go @@ -128,6 +128,8 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head defer close(record.done) // pass a new context to handler, + streamIn.responseHeaders = streamOut.headers + // do not cancel it with the client stream context err := handler(context.Background(), p2p.Peer{Address: r.base}, streamIn) if err != nil && err != io.EOF { diff --git a/pkg/pushsync/pushsync.go b/pkg/pushsync/pushsync.go index 0a96fbe16e5..f4496bc2f94 100644 --- a/pkg/pushsync/pushsync.go +++ b/pkg/pushsync/pushsync.go @@ -27,7 +27,7 @@ import ( "github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/tracing" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" opentracing "github.com/opentracing/opentracing-go" ) @@ -158,12 +158,15 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ps.logger.Errorf("pushsync: chunk store: %v", err) } + debit := ps.accounting.PrepareDebit(p.Address, price) + defer debit.Cleanup() + // return back receipt receipt := pb.Receipt{Address: chunk.Address().Bytes()} if err := w.WriteMsgWithContext(ctxd, &receipt); err != nil { return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) } - return ps.accounting.Debit(p.Address, price) + return debit.Apply() } return ErrOutOfDepthReplication @@ -283,23 +286,29 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) } // return back receipt + debit := ps.accounting.PrepareDebit(p.Address, price) + defer debit.Cleanup() + receipt := pb.Receipt{Address: chunk.Address().Bytes(), Signature: signature} if err := w.WriteMsgWithContext(ctx, &receipt); err != nil { return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) } - return ps.accounting.Debit(p.Address, price) + return debit.Apply() } return fmt.Errorf("handler: push to closest: %w", err) } + debit := ps.accounting.PrepareDebit(p.Address, price) + defer debit.Cleanup() + // pass back the receipt if err := w.WriteMsgWithContext(ctx, receipt); err != nil { return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err) } - return ps.accounting.Debit(p.Address, price) + return debit.Apply() } // PushChunkToClosest sends chunk to the closest peer by opening a stream. It then waits for diff --git a/pkg/retrieval/retrieval.go b/pkg/retrieval/retrieval.go index f565d902627..f01ceb7d6ad 100644 --- a/pkg/retrieval/retrieval.go +++ b/pkg/retrieval/retrieval.go @@ -362,6 +362,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e if err != nil { return fmt.Errorf("stamp marshal: %w", err) } + + chunkPrice := s.pricer.Price(chunk.Address()) + debit := s.accounting.PrepareDebit(p.Address, chunkPrice) + defer debit.Cleanup() + if err := w.WriteMsgWithContext(ctx, &pb.Delivery{ Data: chunk.Data(), Stamp: stamp, @@ -371,8 +376,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String()) - chunkPrice := s.pricer.Price(chunk.Address()) - // debit price from p's balance - return s.accounting.Debit(p.Address, chunkPrice) + return debit.Apply() } diff --git a/pkg/settlement/interface.go b/pkg/settlement/interface.go index e0aeb8d19de..832913e66a2 100644 --- a/pkg/settlement/interface.go +++ b/pkg/settlement/interface.go @@ -18,9 +18,6 @@ var ( // Interface is the interface used by Accounting to trigger settlement type Interface interface { - // Pay initiates a payment to the given peer - // It should return without error it is likely that the payment worked - Pay(ctx context.Context, peer swarm.Address, amount *big.Int) // TotalSent returns the total amount sent to a peer TotalSent(peer swarm.Address) (totalSent *big.Int, err error) // TotalReceived returns the total amount received from a peer @@ -35,4 +32,8 @@ type AccountingAPI interface { PeerDebt(peer swarm.Address) (*big.Int, error) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error NotifyPaymentSent(peer swarm.Address, amount *big.Int, receivedError error) + NotifyRefreshmentReceived(peer swarm.Address, amount *big.Int) error + + Reserve(ctx context.Context, peer swarm.Address, price uint64) error + Release(peer swarm.Address, price uint64) } diff --git a/pkg/settlement/pseudosettle/export_test.go b/pkg/settlement/pseudosettle/export_test.go new file mode 100644 index 00000000000..faecfd14283 --- /dev/null +++ b/pkg/settlement/pseudosettle/export_test.go @@ -0,0 +1,35 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pseudosettle + +import ( + "context" + "time" + + "github.com/ethersphere/bee/pkg/p2p" +) + +const ( + AllowanceFieldName = allowanceFieldName + TimestampFieldName = timestampFieldName +) + +func (s *Service) SetTimeNow(f func() time.Time) { + s.timeNow = f +} + +func (s *Service) SetTime(k int64) { + s.SetTimeNow(func() time.Time { + return time.Unix(k, 0) + }) +} + +func (s *Service) Init(ctx context.Context, peer p2p.Peer) error { + return s.init(ctx, peer) +} + +func (s *Service) Terminate(peer p2p.Peer) error { + return s.terminate(peer) +} diff --git a/pkg/settlement/pseudosettle/pb/pseudosettle.pb.go b/pkg/settlement/pseudosettle/pb/pseudosettle.pb.go index d242e9d698b..41d81fbd1c6 100644 --- a/pkg/settlement/pseudosettle/pb/pseudosettle.pb.go +++ b/pkg/settlement/pseudosettle/pb/pseudosettle.pb.go @@ -23,7 +23,7 @@ var _ = math.Inf const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type Payment struct { - Amount uint64 `protobuf:"varint,1,opt,name=Amount,proto3" json:"Amount,omitempty"` + Amount []byte `protobuf:"bytes,1,opt,name=Amount,proto3" json:"Amount,omitempty"` } func (m *Payment) Reset() { *m = Payment{} } @@ -59,29 +59,84 @@ func (m *Payment) XXX_DiscardUnknown() { var xxx_messageInfo_Payment proto.InternalMessageInfo -func (m *Payment) GetAmount() uint64 { +func (m *Payment) GetAmount() []byte { if m != nil { return m.Amount } + return nil +} + +type PaymentAck struct { + Amount []byte `protobuf:"bytes,1,opt,name=Amount,proto3" json:"Amount,omitempty"` + Timestamp int64 `protobuf:"varint,2,opt,name=Timestamp,proto3" json:"Timestamp,omitempty"` +} + +func (m *PaymentAck) Reset() { *m = PaymentAck{} } +func (m *PaymentAck) String() string { return proto.CompactTextString(m) } +func (*PaymentAck) ProtoMessage() {} +func (*PaymentAck) Descriptor() ([]byte, []int) { + return fileDescriptor_3ff21bb6c9cf5e84, []int{1} +} +func (m *PaymentAck) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PaymentAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PaymentAck.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PaymentAck) XXX_Merge(src proto.Message) { + xxx_messageInfo_PaymentAck.Merge(m, src) +} +func (m *PaymentAck) XXX_Size() int { + return m.Size() +} +func (m *PaymentAck) XXX_DiscardUnknown() { + xxx_messageInfo_PaymentAck.DiscardUnknown(m) +} + +var xxx_messageInfo_PaymentAck proto.InternalMessageInfo + +func (m *PaymentAck) GetAmount() []byte { + if m != nil { + return m.Amount + } + return nil +} + +func (m *PaymentAck) GetTimestamp() int64 { + if m != nil { + return m.Timestamp + } return 0 } func init() { proto.RegisterType((*Payment)(nil), "pseudosettle.Payment") + proto.RegisterType((*PaymentAck)(nil), "pseudosettle.PaymentAck") } func init() { proto.RegisterFile("pseudosettle.proto", fileDescriptor_3ff21bb6c9cf5e84) } var fileDescriptor_3ff21bb6c9cf5e84 = []byte{ - // 114 bytes of a gzipped FileDescriptorProto + // 148 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x28, 0x4e, 0x2d, 0x4d, 0xc9, 0x2f, 0x4e, 0x2d, 0x29, 0xc9, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x41, 0x16, 0x53, 0x52, 0xe4, 0x62, 0x0f, 0x48, 0xac, 0xcc, 0x4d, 0xcd, 0x2b, 0x11, 0x12, 0xe3, - 0x62, 0x73, 0xcc, 0xcd, 0x2f, 0xcd, 0x2b, 0x91, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x09, 0x82, 0xf2, - 0x9c, 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, - 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, - 0x89, 0x0d, 0x6c, 0xaa, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0xfb, 0x97, 0x5c, 0xf8, 0x6b, 0x00, - 0x00, 0x00, + 0x62, 0x73, 0xcc, 0xcd, 0x2f, 0xcd, 0x2b, 0x91, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x82, 0xf2, + 0x94, 0x9c, 0xb8, 0xb8, 0xa0, 0x4a, 0x1c, 0x93, 0xb3, 0x71, 0xa9, 0x12, 0x92, 0xe1, 0xe2, 0x0c, + 0xc9, 0xcc, 0x4d, 0x2d, 0x2e, 0x49, 0xcc, 0x2d, 0x90, 0x60, 0x52, 0x60, 0xd4, 0x60, 0x0e, 0x42, + 0x08, 0x38, 0xc9, 0x9c, 0x78, 0x24, 0xc7, 0x78, 0xe1, 0x91, 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, + 0x13, 0x1e, 0xcb, 0x31, 0x5c, 0x78, 0x2c, 0xc7, 0x70, 0xe3, 0xb1, 0x1c, 0x43, 0x14, 0x53, 0x41, + 0x52, 0x12, 0x1b, 0xd8, 0x65, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x87, 0xcb, 0xb8, 0x18, + 0xaf, 0x00, 0x00, 0x00, } func (m *Payment) Marshal() (dAtA []byte, err error) { @@ -104,10 +159,47 @@ func (m *Payment) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.Amount != 0 { - i = encodeVarintPseudosettle(dAtA, i, uint64(m.Amount)) + if len(m.Amount) > 0 { + i -= len(m.Amount) + copy(dAtA[i:], m.Amount) + i = encodeVarintPseudosettle(dAtA, i, uint64(len(m.Amount))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *PaymentAck) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PaymentAck) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *PaymentAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Timestamp != 0 { + i = encodeVarintPseudosettle(dAtA, i, uint64(m.Timestamp)) i-- - dAtA[i] = 0x8 + dAtA[i] = 0x10 + } + if len(m.Amount) > 0 { + i -= len(m.Amount) + copy(dAtA[i:], m.Amount) + i = encodeVarintPseudosettle(dAtA, i, uint64(len(m.Amount))) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -129,8 +221,25 @@ func (m *Payment) Size() (n int) { } var l int _ = l - if m.Amount != 0 { - n += 1 + sovPseudosettle(uint64(m.Amount)) + l = len(m.Amount) + if l > 0 { + n += 1 + l + sovPseudosettle(uint64(l)) + } + return n +} + +func (m *PaymentAck) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Amount) + if l > 0 { + n += 1 + l + sovPseudosettle(uint64(l)) + } + if m.Timestamp != 0 { + n += 1 + sovPseudosettle(uint64(m.Timestamp)) } return n } @@ -171,10 +280,131 @@ func (m *Payment) Unmarshal(dAtA []byte) error { } switch fieldNum { case 1: - if wireType != 0 { + if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Amount", wireType) } - m.Amount = 0 + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPseudosettle + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthPseudosettle + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthPseudosettle + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Amount = append(m.Amount[:0], dAtA[iNdEx:postIndex]...) + if m.Amount == nil { + m.Amount = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipPseudosettle(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPseudosettle + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthPseudosettle + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *PaymentAck) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPseudosettle + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PaymentAck: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PaymentAck: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Amount", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPseudosettle + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthPseudosettle + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthPseudosettle + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Amount = append(m.Amount[:0], dAtA[iNdEx:postIndex]...) + if m.Amount == nil { + m.Amount = []byte{} + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Timestamp", wireType) + } + m.Timestamp = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowPseudosettle @@ -184,7 +414,7 @@ func (m *Payment) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Amount |= uint64(b&0x7F) << shift + m.Timestamp |= int64(b&0x7F) << shift if b < 0x80 { break } diff --git a/pkg/settlement/pseudosettle/pb/pseudosettle.proto b/pkg/settlement/pseudosettle/pb/pseudosettle.proto index d5bbc8c6cbe..c5b9b1e35b6 100644 --- a/pkg/settlement/pseudosettle/pb/pseudosettle.proto +++ b/pkg/settlement/pseudosettle/pb/pseudosettle.proto @@ -9,5 +9,10 @@ package pseudosettle; option go_package = "pb"; message Payment { - uint64 Amount = 1; + bytes Amount = 1; } + +message PaymentAck { + bytes Amount = 1; + int64 Timestamp = 2; +} \ No newline at end of file diff --git a/pkg/settlement/pseudosettle/pseudosettle.go b/pkg/settlement/pseudosettle/pseudosettle.go index ebcd983fa14..47ee851a1ce 100644 --- a/pkg/settlement/pseudosettle/pseudosettle.go +++ b/pkg/settlement/pseudosettle/pseudosettle.go @@ -10,6 +10,7 @@ import ( "fmt" "math/big" "strings" + "sync" "time" "github.com/ethersphere/bee/pkg/logging" @@ -30,6 +31,11 @@ const ( var ( SettlementReceivedPrefix = "pseudosettle_total_received_" SettlementSentPrefix = "pseudosettle_total_sent_" + + ErrSettlementTooSoon = errors.New("settlement too soon") + ErrNoPseudoSettlePeer = errors.New("settlement peer not found") + ErrDisconnectAllowanceCheckFailed = errors.New("settlement allowance below enforced amount") + ErrTimeOutOfSync = errors.New("settlement allowance timestamps differ beyond tolerance") ) type Service struct { @@ -38,15 +44,34 @@ type Service struct { store storage.StateStorer accountingAPI settlement.AccountingAPI metrics metrics + refreshRate *big.Int + p2pService p2p.Service + timeNow func() time.Time + peersMu sync.Mutex + peers map[string]*pseudoSettlePeer +} + +type pseudoSettlePeer struct { + lock sync.Mutex // lock to be held during receiving a payment from this peer } -func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer, accountingAPI settlement.AccountingAPI) *Service { +type lastPayment struct { + Timestamp int64 + CheckTimestamp int64 + Total *big.Int +} + +func New(streamer p2p.Streamer, logger logging.Logger, store storage.StateStorer, accountingAPI settlement.AccountingAPI, refreshRate *big.Int, p2pService p2p.Service) *Service { return &Service{ streamer: streamer, logger: logger, metrics: newMetrics(), store: store, accountingAPI: accountingAPI, + p2pService: p2pService, + refreshRate: refreshRate, + timeNow: time.Now, + peers: make(map[string]*pseudoSettlePeer), } } @@ -60,7 +85,32 @@ func (s *Service) Protocol() p2p.ProtocolSpec { Handler: s.handler, }, }, + ConnectIn: s.init, + ConnectOut: s.init, + DisconnectIn: s.terminate, + DisconnectOut: s.terminate, + } +} + +func (s *Service) init(ctx context.Context, p p2p.Peer) error { + s.peersMu.Lock() + defer s.peersMu.Unlock() + + _, ok := s.peers[p.Address.String()] + if !ok { + peerData := &pseudoSettlePeer{} + s.peers[p.Address.String()] = peerData } + + return nil +} + +func (s *Service) terminate(p p2p.Peer) error { + s.peersMu.Lock() + defer s.peersMu.Unlock() + + delete(s.peers, p.Address.String()) + return nil } func totalKey(peer swarm.Address, prefix string) string { @@ -77,13 +127,45 @@ func totalKeyPeer(key []byte, prefix string) (peer swarm.Address, err error) { return swarm.ParseHexAddress(split[1]) } +// peerAllowance computes the maximum incoming payment value we accept +// this is the time based allowance or the peers actual debt, whichever is less +func (s *Service) peerAllowance(peer swarm.Address) (limit *big.Int, stamp int64, err error) { + var lastTime lastPayment + err = s.store.Get(totalKey(peer, SettlementReceivedPrefix), &lastTime) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return nil, 0, err + } + lastTime.Timestamp = int64(0) + } + + currentTime := s.timeNow().Unix() + if currentTime == lastTime.Timestamp { + return nil, 0, ErrSettlementTooSoon + } + + maxAllowance := new(big.Int).Mul(big.NewInt(currentTime-lastTime.Timestamp), s.refreshRate) + + peerDebt, err := s.accountingAPI.PeerDebt(peer) + if err != nil { + return nil, 0, err + } + + if peerDebt.Cmp(maxAllowance) >= 0 { + return maxAllowance, currentTime, nil + } + + return peerDebt, currentTime, nil +} + func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) { - r := protobuf.NewReader(stream) + // lock peer here + w, r := protobuf.NewWriterAndReader(stream) defer func() { if err != nil { _ = stream.Reset() } else { - _ = stream.FullClose() + go stream.FullClose() } }() var req pb.Payment @@ -91,70 +173,182 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e return fmt.Errorf("read request from peer %v: %w", p.Address, err) } - s.metrics.TotalReceivedPseudoSettlements.Add(float64(req.Amount)) - s.logger.Tracef("received payment message from peer %v of %d", p.Address, req.Amount) + attemptedAmount := big.NewInt(0).SetBytes(req.Amount) + + paymentAmount := new(big.Int).Set(attemptedAmount) + + s.peersMu.Lock() + pseudoSettlePeer, ok := s.peers[p.Address.String()] + if !ok { + s.peersMu.Unlock() + return ErrNoPseudoSettlePeer + } + s.peersMu.Unlock() + + pseudoSettlePeer.lock.Lock() + defer pseudoSettlePeer.lock.Unlock() + + allowance, timestamp, err := s.peerAllowance(p.Address) + if err != nil { + return err + } + + if allowance.Cmp(attemptedAmount) < 0 { + paymentAmount.Set(allowance) + s.logger.Tracef("pseudosettle accepting reduced payment from peer %v of %d", p.Address, paymentAmount) + } else { + s.logger.Tracef("pseudosettle accepting payment message from peer %v of %d", p.Address, paymentAmount) + } + + if paymentAmount.Cmp(big.NewInt(0)) < 0 { + paymentAmount.Set(big.NewInt(0)) + } + + err = s.accountingAPI.Reserve(ctx, p.Address, paymentAmount.Uint64()) + if err != nil { + return err + } + defer s.accountingAPI.Release(p.Address, paymentAmount.Uint64()) + + err = w.WriteMsgWithContext(ctx, &pb.PaymentAck{ + Amount: paymentAmount.Bytes(), + Timestamp: timestamp, + }) + if err != nil { + return err + } - totalReceived, err := s.TotalReceived(p.Address) + var lastTime lastPayment + err = s.store.Get(totalKey(p.Address, SettlementReceivedPrefix), &lastTime) if err != nil { - if !errors.Is(err, settlement.ErrPeerNoSettlements) { + if !errors.Is(err, storage.ErrNotFound) { return err } - totalReceived = big.NewInt(0) + lastTime.Total = big.NewInt(0) } - err = s.store.Put(totalKey(p.Address, SettlementReceivedPrefix), totalReceived.Add(totalReceived, new(big.Int).SetUint64(req.Amount))) + lastTime.Total = lastTime.Total.Add(lastTime.Total, paymentAmount) + lastTime.Timestamp = timestamp + + err = s.store.Put(totalKey(p.Address, SettlementReceivedPrefix), lastTime) if err != nil { return err } - return s.accountingAPI.NotifyPaymentReceived(p.Address, new(big.Int).SetUint64(req.Amount)) + receivedPaymentF64, _ := big.NewFloat(0).SetInt(paymentAmount).Float64() + s.metrics.TotalReceivedPseudoSettlements.Add(receivedPaymentF64) + return s.accountingAPI.NotifyRefreshmentReceived(p.Address, paymentAmount) } // Pay initiates a payment to the given peer -func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) { +func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int, checkAllowance *big.Int) (*big.Int, int64, error) { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() + var err error - defer func() { - if err != nil { - s.accountingAPI.NotifyPaymentSent(peer, nil, err) + + var lastTime lastPayment + err = s.store.Get(totalKey(peer, SettlementSentPrefix), &lastTime) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return nil, 0, err } - }() + lastTime.Total = big.NewInt(0) + lastTime.Timestamp = 0 + } + + currentTime := s.timeNow().Unix() + if currentTime == lastTime.Timestamp { + return nil, 0, ErrSettlementTooSoon + } + stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) if err != nil { - return + return nil, 0, err } defer func() { if err != nil { _ = stream.Reset() } else { - go stream.FullClose() + _ = stream.FullClose() } }() - s.logger.Tracef("sending payment message to peer %v of %d", peer, amount) - w := protobuf.NewWriter(stream) + if checkAllowance.Cmp(amount) > 0 { + checkAllowance.Set(amount) + } + + s.logger.Tracef("pseudosettle sending payment message to peer %v of %d", peer, amount) + w, r := protobuf.NewWriterAndReader(stream) + err = w.WriteMsgWithContext(ctx, &pb.Payment{ - Amount: amount.Uint64(), + Amount: amount.Bytes(), }) if err != nil { - return + return nil, 0, err } - totalSent, err := s.TotalSent(peer) + + checkTime := s.timeNow().Unix() + + var paymentAck pb.PaymentAck + err = r.ReadMsgWithContext(ctx, &paymentAck) if err != nil { - if !errors.Is(err, settlement.ErrPeerNoSettlements) { - return + return nil, 0, err + } + + acceptedAmount := new(big.Int).SetBytes(paymentAck.Amount) + if acceptedAmount.Cmp(amount) > 0 { + err = fmt.Errorf("pseudosettle other peer %v accepted payment larger than expected", peer) + return nil, 0, err + } + + experiencedInterval := checkTime - lastTime.CheckTimestamp + allegedInterval := paymentAck.Timestamp - lastTime.Timestamp + + if allegedInterval < 0 { + return nil, 0, ErrTimeOutOfSync + } + + experienceDifferenceRecent := paymentAck.Timestamp - checkTime + + if experienceDifferenceRecent < -2 || experienceDifferenceRecent > 2 { + return nil, 0, ErrTimeOutOfSync + } + + experienceDifferenceInterval := experiencedInterval - allegedInterval + if experienceDifferenceInterval < -3 || experienceDifferenceInterval > 3 { + return nil, 0, ErrTimeOutOfSync + } + + // enforce allowance + // check if value is appropriate + expectedAllowance := new(big.Int).Mul(big.NewInt(allegedInterval), s.refreshRate) + if expectedAllowance.Cmp(checkAllowance) > 0 { + expectedAllowance = new(big.Int).Set(checkAllowance) + } + + if expectedAllowance.Cmp(acceptedAmount) > 0 { + // disconnect peer + err = s.p2pService.Blocklist(peer, 10000*time.Hour) + if err != nil { + return nil, 0, err } - totalSent = big.NewInt(0) + return nil, 0, ErrDisconnectAllowanceCheckFailed } - err = s.store.Put(totalKey(peer, SettlementSentPrefix), totalSent.Add(totalSent, amount)) + lastTime.Total = lastTime.Total.Add(lastTime.Total, acceptedAmount) + lastTime.Timestamp = paymentAck.Timestamp + lastTime.CheckTimestamp = checkTime + + err = s.store.Put(totalKey(peer, SettlementSentPrefix), lastTime) if err != nil { - return + return nil, 0, err } - s.accountingAPI.NotifyPaymentSent(peer, amount, nil) - amountFloat, _ := new(big.Float).SetInt(amount).Float64() + + amountFloat, _ := new(big.Float).SetInt(acceptedAmount).Float64() s.metrics.TotalSentPseudoSettlements.Add(amountFloat) + + return acceptedAmount, lastTime.CheckTimestamp, nil } func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) { @@ -163,28 +357,32 @@ func (s *Service) SetAccountingAPI(accountingAPI settlement.AccountingAPI) { // TotalSent returns the total amount sent to a peer func (s *Service) TotalSent(peer swarm.Address) (totalSent *big.Int, err error) { - key := totalKey(peer, SettlementSentPrefix) - err = s.store.Get(key, &totalSent) + var lastTime lastPayment + + err = s.store.Get(totalKey(peer, SettlementSentPrefix), &lastTime) if err != nil { - if errors.Is(err, storage.ErrNotFound) { + if !errors.Is(err, storage.ErrNotFound) { return nil, settlement.ErrPeerNoSettlements } - return nil, err + lastTime.Total = big.NewInt(0) } - return totalSent, nil + + return lastTime.Total, nil } // TotalReceived returns the total amount received from a peer func (s *Service) TotalReceived(peer swarm.Address) (totalReceived *big.Int, err error) { - key := totalKey(peer, SettlementReceivedPrefix) - err = s.store.Get(key, &totalReceived) + var lastTime lastPayment + + err = s.store.Get(totalKey(peer, SettlementReceivedPrefix), &lastTime) if err != nil { - if errors.Is(err, storage.ErrNotFound) { + if !errors.Is(err, storage.ErrNotFound) { return nil, settlement.ErrPeerNoSettlements } - return nil, err + lastTime.Total = big.NewInt(0) } - return totalReceived, nil + + return lastTime.Total, nil } // SettlementsSent returns all stored sent settlement values for a given type of prefix @@ -196,13 +394,13 @@ func (s *Service) SettlementsSent() (map[string]*big.Int, error) { return false, fmt.Errorf("parse address from key: %s: %w", string(key), err) } if _, ok := sent[addr.String()]; !ok { - var storevalue *big.Int + var storevalue lastPayment err = s.store.Get(totalKey(addr, SettlementSentPrefix), &storevalue) if err != nil { return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err) } - sent[addr.String()] = storevalue + sent[addr.String()] = storevalue.Total } return false, nil }) @@ -221,13 +419,13 @@ func (s *Service) SettlementsReceived() (map[string]*big.Int, error) { return false, fmt.Errorf("parse address from key: %s: %w", string(key), err) } if _, ok := received[addr.String()]; !ok { - var storevalue *big.Int + var storevalue lastPayment err = s.store.Get(totalKey(addr, SettlementReceivedPrefix), &storevalue) if err != nil { return false, fmt.Errorf("get peer %s settlement balance: %w", addr.String(), err) } - received[addr.String()] = storevalue + received[addr.String()] = storevalue.Total } return false, nil }) diff --git a/pkg/settlement/pseudosettle/pseudosettle_test.go b/pkg/settlement/pseudosettle/pseudosettle_test.go index a3c2ce13801..a7b23ca952e 100644 --- a/pkg/settlement/pseudosettle/pseudosettle_test.go +++ b/pkg/settlement/pseudosettle/pseudosettle_test.go @@ -7,12 +7,15 @@ package pseudosettle_test import ( "bytes" "context" + "errors" "io/ioutil" "math/big" "testing" "time" "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/p2p" + mockp2p "github.com/ethersphere/bee/pkg/p2p/mock" "github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/settlement/pseudosettle" @@ -22,8 +25,10 @@ import ( ) type testObserver struct { - receivedCalled chan notifyPaymentReceivedCall - sentCalled chan notifyPaymentSentCall + receivedCalled chan notifyPaymentReceivedCall + sentCalled chan notifyPaymentSentCall + peerDebts map[string]*big.Int + peerShadowBalances map[string]*big.Int } type notifyPaymentReceivedCall struct { @@ -37,15 +42,37 @@ type notifyPaymentSentCall struct { err error } -func newTestObserver() *testObserver { +func newTestObserver(debtAmounts map[string]*big.Int, shadowBalanceAmounts map[string]*big.Int) *testObserver { return &testObserver{ - receivedCalled: make(chan notifyPaymentReceivedCall, 1), - sentCalled: make(chan notifyPaymentSentCall, 1), + receivedCalled: make(chan notifyPaymentReceivedCall, 1), + sentCalled: make(chan notifyPaymentSentCall, 1), + peerDebts: debtAmounts, + peerShadowBalances: shadowBalanceAmounts, } } +func (t *testObserver) setPeerDebt(peer swarm.Address, debt *big.Int) { + t.peerDebts[peer.String()] = debt +} + +func (t *testObserver) setPeerShadowBalance(peer swarm.Address, debt *big.Int) { + t.peerShadowBalances[peer.String()] = debt +} + func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) { - return nil, nil + if debt, ok := t.peerDebts[peer.String()]; ok { + return debt, nil + } + + return nil, errors.New("Peer not listed") +} + +func (t *testObserver) ShadowBalance(peer swarm.Address) (*big.Int, error) { + if debt, ok := t.peerShadowBalances[peer.String()]; ok { + return debt, nil + } + + return nil, errors.New("Peer not listed") } func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error { @@ -63,16 +90,157 @@ func (t *testObserver) NotifyPaymentSent(peer swarm.Address, amount *big.Int, er err: err, } } + +var testRefreshRate = int64(10000) + func TestPayment(t *testing.T) { logger := logging.New(ioutil.Discard, 0) storeRecipient := mock.NewStateStore() defer storeRecipient.Close() - observer := newTestObserver() - recipient := pseudosettle.New(nil, logger, storeRecipient, observer) + peerID := swarm.MustParseHexAddress("9ee7add7") + peer := p2p.Peer{Address: peerID} + + debt := int64(10000) + + observer := newTestObserver(map[string]*big.Int{peerID.String(): big.NewInt(debt)}, map[string]*big.Int{}) + recipient := pseudosettle.New(nil, logger, storeRecipient, observer, big.NewInt(testRefreshRate), mockp2p.New()) + recipient.SetAccountingAPI(observer) + err := recipient.Init(context.Background(), peer) + if err != nil { + t.Fatal(err) + } + + recorder := streamtest.New( + streamtest.WithProtocols(recipient.Protocol()), + streamtest.WithBaseAddr(peerID), + ) + + storePayer := mock.NewStateStore() + defer storePayer.Close() + + observer2 := newTestObserver(map[string]*big.Int{}, map[string]*big.Int{peerID.String(): big.NewInt(debt)}) + payer := pseudosettle.New(recorder, logger, storePayer, observer2, big.NewInt(testRefreshRate), mockp2p.New()) + payer.SetAccountingAPI(observer2) + + amount := big.NewInt(debt) + + payer.Pay(context.Background(), peerID, amount) + + records, err := recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 1 { + t.Fatalf("got %v records, want %v", l, 1) + } + + record := records[0] + + if err := record.Err(); err != nil { + t.Fatalf("record error: %v", err) + } + + messages, err := protobuf.ReadMessages( + bytes.NewReader(record.In()), + func() protobuf.Message { return new(pb.Payment) }, + ) + if err != nil { + t.Fatal(err) + } + + receivedMessages, err := protobuf.ReadMessages( + bytes.NewReader(record.Out()), + func() protobuf.Message { return new(pb.PaymentAck) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 || len(receivedMessages) != 1 { + t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1) + } + + sentAmount := big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + receivedAmount := big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount) + if sentAmount.Cmp(amount) != 0 { + t.Fatalf("got message with amount %v, want %v", sentAmount, amount) + } + + if sentAmount.Cmp(receivedAmount) != 0 { + t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount) + } + + select { + case call := <-observer.receivedCalled: + if call.amount.Cmp(amount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + select { + case call := <-observer2.sentCalled: + if call.amount.Cmp(amount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err != nil { + t.Fatalf("observer called with error. got %v want nil", call.err) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + totalSent, err := payer.TotalSent(peerID) + if err != nil { + t.Fatal(err) + } + + if totalSent.Cmp(sentAmount) != 0 { + t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentAmount) + } + + totalReceived, err := recipient.TotalReceived(peerID) + if err != nil { + t.Fatal(err) + } + + if totalReceived.Cmp(sentAmount) != 0 { + t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentAmount) + } +} + +func TestTimeLimitedPayment(t *testing.T) { + logger := logging.New(ioutil.Discard, 0) + + storeRecipient := mock.NewStateStore() + defer storeRecipient.Close() peerID := swarm.MustParseHexAddress("9ee7add7") + peer := p2p.Peer{Address: peerID} + + debt := testRefreshRate + + observer := newTestObserver(map[string]*big.Int{peerID.String(): big.NewInt(debt)}, map[string]*big.Int{}) + recipient := pseudosettle.New(nil, logger, storeRecipient, observer, big.NewInt(testRefreshRate), mockp2p.New()) + recipient.SetAccountingAPI(observer) + err := recipient.Init(context.Background(), peer) + if err != nil { + t.Fatal(err) + } recorder := streamtest.New( streamtest.WithProtocols(recipient.Protocol()), @@ -82,11 +250,14 @@ func TestPayment(t *testing.T) { storePayer := mock.NewStateStore() defer storePayer.Close() - observer2 := newTestObserver() - payer := pseudosettle.New(recorder, logger, storePayer, observer2) + observer2 := newTestObserver(map[string]*big.Int{}, map[string]*big.Int{peerID.String(): big.NewInt(debt)}) + payer := pseudosettle.New(recorder, logger, storePayer, observer2, big.NewInt(testRefreshRate), mockp2p.New()) payer.SetAccountingAPI(observer2) - amount := big.NewInt(10000) + payer.SetTime(int64(10000)) + recipient.SetTime(int64(10000)) + + amount := big.NewInt(debt) payer.Pay(context.Background(), peerID, amount) @@ -113,15 +284,28 @@ func TestPayment(t *testing.T) { t.Fatal(err) } - if len(messages) != 1 { - t.Fatalf("got %v messages, want %v", len(messages), 1) + receivedMessages, err := protobuf.ReadMessages( + bytes.NewReader(record.Out()), + func() protobuf.Message { return new(pb.PaymentAck) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 || len(receivedMessages) != 1 { + t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1) } - sentAmount := messages[0].(*pb.Payment).Amount - if sentAmount != amount.Uint64() { + sentAmount := big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + receivedAmount := big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount) + if sentAmount.Cmp(amount) != 0 { t.Fatalf("got message with amount %v, want %v", sentAmount, amount) } + if sentAmount.Cmp(receivedAmount) != 0 { + t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount) + } + select { case call := <-observer.receivedCalled: if call.amount.Cmp(amount) != 0 { @@ -158,7 +342,7 @@ func TestPayment(t *testing.T) { t.Fatal(err) } - if totalSent.Cmp(new(big.Int).SetUint64(sentAmount)) != 0 { + if totalSent.Cmp(sentAmount) != 0 { t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentAmount) } @@ -167,7 +351,528 @@ func TestPayment(t *testing.T) { t.Fatal(err) } - if totalReceived.Cmp(new(big.Int).SetUint64(sentAmount)) != 0 { + if totalReceived.Cmp(sentAmount) != 0 { t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentAmount) } + + sentSum := big.NewInt(testRefreshRate) + + // Let 3 seconds pass, attempt settlement below time based refreshment rate + + debt = testRefreshRate * 3 / 2 + amount = big.NewInt(debt) + + payer.SetTime(int64(10003)) + recipient.SetTime(int64(10003)) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + sentSum = sentSum.Add(sentSum, amount) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 2 { + t.Fatalf("got %v records, want %v", l, 2) + } + record = records[1] + + if err := record.Err(); err != nil { + t.Fatalf("record error: %v", err) + } + + messages, err = protobuf.ReadMessages( + bytes.NewReader(record.In()), + func() protobuf.Message { return new(pb.Payment) }, + ) + if err != nil { + t.Fatal(err) + } + + receivedMessages, err = protobuf.ReadMessages( + bytes.NewReader(record.Out()), + func() protobuf.Message { return new(pb.PaymentAck) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 || len(receivedMessages) != 1 { + t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1) + } + + sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount) + if sentAmount.Cmp(amount) != 0 { + t.Fatalf("got message with amount %v, want %v", sentAmount, amount) + } + + if sentAmount.Cmp(receivedAmount) != 0 { + t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, sentAmount) + } + + select { + case call := <-observer.receivedCalled: + if call.amount.Cmp(receivedAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + select { + case call := <-observer2.sentCalled: + if call.amount.Cmp(receivedAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err != nil { + t.Fatalf("observer called with error. got %v want nil", call.err) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + totalSent, err = payer.TotalSent(peerID) + if err != nil { + t.Fatal(err) + } + + if totalSent.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum) + } + + totalReceived, err = recipient.TotalReceived(peerID) + if err != nil { + t.Fatal(err) + } + + if totalReceived.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum) + } + + // attempt settlement over the time-based allowed limit 1 seconds later + + debt = 3 * testRefreshRate + amount = big.NewInt(debt) + + payer.SetTime(int64(10004)) + recipient.SetTime(int64(10004)) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + testRefreshRateBigInt := big.NewInt(testRefreshRate) + + sentSum = sentSum.Add(sentSum, testRefreshRateBigInt) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 3 { + t.Fatalf("got %v records, want %v", l, 3) + } + + record = records[2] + + if err := record.Err(); err != nil { + t.Fatalf("record error: %v", err) + } + + messages, err = protobuf.ReadMessages( + bytes.NewReader(record.In()), + func() protobuf.Message { return new(pb.Payment) }, + ) + if err != nil { + t.Fatal(err) + } + + receivedMessages, err = protobuf.ReadMessages( + bytes.NewReader(record.Out()), + func() protobuf.Message { return new(pb.PaymentAck) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 || len(receivedMessages) != 1 { + t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1) + } + + sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount) + if sentAmount.Cmp(amount) != 0 { + t.Fatalf("got message with amount %v, want %v", sentAmount, amount) + } + + if receivedAmount.Cmp(testRefreshRateBigInt) != 0 { + t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, testRefreshRateBigInt) + } + + select { + case call := <-observer.receivedCalled: + if call.amount.Cmp(testRefreshRateBigInt) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testRefreshRate) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + select { + case call := <-observer2.sentCalled: + if call.amount.Cmp(testRefreshRateBigInt) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testRefreshRate) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err != nil { + t.Fatalf("observer called with error. got %v want nil", call.err) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + totalSent, err = payer.TotalSent(peerID) + if err != nil { + t.Fatal(err) + } + + if totalSent.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum) + } + + totalReceived, err = recipient.TotalReceived(peerID) + if err != nil { + t.Fatal(err) + } + + if totalReceived.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum) + } + + // attempt settle again in the same second without success + + debt = 4 * testRefreshRate + amount = big.NewInt(debt) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 3 { + t.Fatalf("got %v records, want %v", l, 3) + } + + select { + case <-observer.receivedCalled: + t.Fatal("unexpected observer to be called") + + case <-time.After(time.Second): + + } + + select { + case call := <-observer2.sentCalled: + if call.amount != nil { + t.Fatalf("observer called with wrong amount. got %d, want nil", call.amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err == nil { + t.Fatalf("observer called without error. got nil want err") + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + // attempt again while recipient is still supposed to be blocking based on time + + debt = 2 * testRefreshRate + amount = big.NewInt(debt) + + payer.SetTime(int64(10005)) + recipient.SetTime(int64(10004)) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 4 { + t.Fatalf("got %v records, want %v", l, 4) + } + + select { + case <-observer.receivedCalled: + t.Fatal("unexpected observer to be called") + + case <-time.After(time.Second): + + } + + select { + case call := <-observer2.sentCalled: + if call.amount != nil { + t.Fatalf("observer called with wrong amount. got %d, want nil", call.amount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err == nil { + t.Fatalf("observer called without error. got nil want err") + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + // attempt multiple seconds later with debt over time based allowance + + debt = 9 * testRefreshRate + amount = big.NewInt(debt) + + payer.SetTime(int64(10010)) + recipient.SetTime(int64(10010)) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + sentSum = sentSum.Add(sentSum, big.NewInt(6*testRefreshRate)) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 5 { + t.Fatalf("got %v records, want %v", l, 5) + } + + record = records[4] + + if err := record.Err(); err != nil { + t.Fatalf("record error: %v", err) + } + + messages, err = protobuf.ReadMessages( + bytes.NewReader(record.In()), + func() protobuf.Message { return new(pb.Payment) }, + ) + if err != nil { + t.Fatal(err) + } + + receivedMessages, err = protobuf.ReadMessages( + bytes.NewReader(record.Out()), + func() protobuf.Message { return new(pb.PaymentAck) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 || len(receivedMessages) != 1 { + t.Fatalf("got %v/%v messages, want %v/%v", len(messages), len(receivedMessages), 1, 1) + } + + testAmount := big.NewInt(6 * testRefreshRate) + + sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + receivedAmount = big.NewInt(0).SetBytes(receivedMessages[0].(*pb.PaymentAck).Amount) + if sentAmount.Cmp(amount) != 0 { + t.Fatalf("got message with amount %v, want %v", sentAmount, amount) + } + + if receivedAmount.Cmp(testAmount) != 0 { + t.Fatalf("wrong settlement amount, got %v, want %v", receivedAmount, testAmount) + } + + select { + case call := <-observer.receivedCalled: + if call.amount.Cmp(testAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + select { + case call := <-observer2.sentCalled: + if call.amount.Cmp(testAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err != nil { + t.Fatalf("observer called with error. got %v want nil", call.err) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + totalSent, err = payer.TotalSent(peerID) + if err != nil { + t.Fatal(err) + } + + if totalSent.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum) + } + + totalReceived, err = recipient.TotalReceived(peerID) + if err != nil { + t.Fatal(err) + } + + if totalReceived.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum) + } + + // attempt further settlement with less outstanding debt than time allowance would allow + + debt = 5 * testRefreshRate + amount = big.NewInt(debt) + + payer.SetTime(int64(10020)) + recipient.SetTime(int64(10020)) + + observer.setPeerDebt(peerID, amount) + observer2.setPeerShadowBalance(peerID, amount) + + payer.Pay(context.Background(), peerID, amount) + + sentSum = sentSum.Add(sentSum, big.NewInt(5*testRefreshRate)) + + records, err = recorder.Records(peerID, "pseudosettle", "1.0.0", "pseudosettle") + if err != nil { + t.Fatal(err) + } + + if l := len(records); l != 6 { + t.Fatalf("got %v records, want %v", l, 5) + } + + record = records[5] + + if err := record.Err(); err != nil { + t.Fatalf("record error: %v", err) + } + + messages, err = protobuf.ReadMessages( + bytes.NewReader(record.In()), + func() protobuf.Message { return new(pb.Payment) }, + ) + if err != nil { + t.Fatal(err) + } + + if len(messages) != 1 { + t.Fatalf("got %v messages, want %v", len(messages), 1) + } + + testAmount = big.NewInt(5 * testRefreshRate) + + sentAmount = big.NewInt(0).SetBytes(messages[0].(*pb.Payment).Amount) + if sentAmount.Cmp(testAmount) != 0 { + t.Fatalf("got message with amount %v, want %v", sentAmount, testAmount) + } + + select { + case call := <-observer.receivedCalled: + if call.amount.Cmp(testAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + select { + case call := <-observer2.sentCalled: + if call.amount.Cmp(testAmount) != 0 { + t.Fatalf("observer called with wrong amount. got %d, want %d", call.amount, testAmount) + } + + if !call.peer.Equal(peerID) { + t.Fatalf("observer called with wrong peer. got %v, want %v", call.peer, peerID) + } + if call.err != nil { + t.Fatalf("observer called with error. got %v want nil", call.err) + } + + case <-time.After(time.Second): + t.Fatal("expected observer to be called") + } + + totalSent, err = payer.TotalSent(peerID) + if err != nil { + t.Fatal(err) + } + + if totalSent.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalSent. got %d, want %d", totalSent, sentSum) + } + + totalReceived, err = recipient.TotalReceived(peerID) + if err != nil { + t.Fatal(err) + } + + if totalReceived.Cmp(sentSum) != 0 { + t.Fatalf("stored wrong totalReceived. got %d, want %d", totalReceived, sentSum) + } } diff --git a/pkg/settlement/swap/swap.go b/pkg/settlement/swap/swap.go index 20946380cfa..12d4bb4dd0a 100644 --- a/pkg/settlement/swap/swap.go +++ b/pkg/settlement/swap/swap.go @@ -101,7 +101,8 @@ func (s *Service) ReceiveCheque(ctx context.Context, peer swarm.Address, cheque } } - s.metrics.TotalReceived.Add(float64(amount.Uint64())) + tot, _ := big.NewFloat(0).SetInt(amount).Float64() + s.metrics.TotalReceived.Add(tot) s.metrics.ChequesReceived.Inc() return s.accountingAPI.NotifyPaymentReceived(peer, amount) @@ -112,7 +113,7 @@ func (s *Service) Pay(ctx context.Context, peer swarm.Address, amount *big.Int) var err error defer func() { if err != nil { - s.accountingAPI.NotifyPaymentSent(peer, nil, err) + s.accountingAPI.NotifyPaymentSent(peer, amount, err) } }() beneficiary, known, err := s.addressbook.Beneficiary(peer) diff --git a/pkg/settlement/swap/swap_test.go b/pkg/settlement/swap/swap_test.go index e0fd091d12f..dcd2cb6a114 100644 --- a/pkg/settlement/swap/swap_test.go +++ b/pkg/settlement/swap/swap_test.go @@ -62,6 +62,10 @@ func (t *testObserver) PeerDebt(peer swarm.Address) (*big.Int, error) { return nil, nil } +func (t *testObserver) ShadowBalance(peer swarm.Address) (*big.Int, error) { + return nil, nil +} + func (t *testObserver) NotifyPaymentReceived(peer swarm.Address, amount *big.Int) error { t.receivedCalled <- notifyPaymentReceivedCall{ peer: peer,