diff --git a/swap/protocol.go b/swap/protocol.go index f6e3a4df35..2e9642e23a 100644 --- a/swap/protocol.go +++ b/swap/protocol.go @@ -34,6 +34,9 @@ var ( // ErrEmptyAddressInSignature is used when the empty address is used for the chequebook in the handshake ErrEmptyAddressInSignature = errors.New("empty address in handshake") + // ErrDifferentChainID is used when the chain id exchanged during the handshake does not match + ErrDifferentChainID = errors.New("different chain id") + // ErrInvalidHandshakeMsg is used when the message received during handshake does not conform to the // structure of the HandshakeMsg ErrInvalidHandshakeMsg = errors.New("invalid handshake message") @@ -87,7 +90,7 @@ func (s *Swap) Stop() error { return s.Close() } -// verifyHandshake verifies the chequebook address transmitted in the swap handshake +// verifyHandshake verifies the chequebook address and chain id transmitted in the swap handshake func (s *Swap) verifyHandshake(msg interface{}) error { handshake, ok := msg.(*HandshakeMsg) if !ok { @@ -98,6 +101,10 @@ func (s *Swap) verifyHandshake(msg interface{}) error { return ErrEmptyAddressInSignature } + if handshake.ChainID != s.chainID { + return ErrDifferentChainID + } + return s.chequebookFactory.VerifyContract(handshake.ContractAddress) } @@ -107,6 +114,7 @@ func (s *Swap) run(p *p2p.Peer, rw p2p.MsgReadWriter) error { handshake, err := protoPeer.Handshake(context.Background(), &HandshakeMsg{ ContractAddress: s.GetParams().ContractAddress, + ChainID: s.chainID, }, s.verifyHandshake) if err != nil { return err diff --git a/swap/protocol_test.go b/swap/protocol_test.go index 41bc6ea51d..2fd708ab54 100644 --- a/swap/protocol_test.go +++ b/swap/protocol_test.go @@ -44,68 +44,174 @@ func init() { log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) } -/* -TestHandshake creates two mock nodes and initiates an exchange; -it expects a handshake to take place between the two nodes -(the handshake would fail because we don't actually use real nodes here) -*/ -func TestHandshake(t *testing.T) { - var err error +// protocol tester based on a swap instance +type swapTester struct { + *p2ptest.ProtocolTester + swap *Swap +} - // setup test swap object +// creates a new protocol tester for swap with a deployed chequebook +func newSwapTester(t *testing.T) (*swapTester, func(), error) { swap, clean := newTestSwap(t, ownerKey, nil) - defer clean() - ctx := context.Background() - err = testDeploy(ctx, swap) + err := testDeploy(context.Background(), swap) if err != nil { - t.Fatal(err) + return nil, nil, err } + // setup the protocolTester, which will allow protocol testing by sending messages - protocolTester := p2ptest.NewProtocolTester(swap.owner.privateKey, 2, swap.run) + protocolTester := p2ptest.NewProtocolTester(swap.owner.privateKey, 1, swap.run) + return &swapTester{ + ProtocolTester: protocolTester, + swap: swap, + }, clean, nil +} + +// creates a test exchange for the handshakes +func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchange { + return []p2ptest.Exchange{ + { + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: lhs, + Peer: id, + }, + }, + }, + { + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: rhs, + Peer: id, + }, + }, + }, + } +} - // shortcut to creditor node - debitor := protocolTester.Nodes[0] - creditor := protocolTester.Nodes[1] +// helper function for testing the handshake +// lhs is the HandshakeMsg we expect to be sent, rhs the one we receive +// disconnects is a list of disconnect events to be expected +func (s *swapTester) testHandshake(lhs, rhs *HandshakeMsg, disconnects ...*p2ptest.Disconnect) error { + if err := s.TestExchanges(HandshakeMsgExchange(lhs, rhs, s.Nodes[0].ID())...); err != nil { + return err + } - // set balance artifially - swap.saveBalance(creditor.ID(), -42) + if len(disconnects) > 0 { + return s.TestDisconnected(disconnects...) + } - // create the expected cheque to be received - cheque := newTestCheque() + // If we don't expect disconnect, ensure peers remain connected + err := s.TestDisconnected(&p2ptest.Disconnect{ + Peer: s.Nodes[0].ID(), + Error: nil, + }) - // sign the cheque - cheque.Signature, err = cheque.Sign(swap.owner.privateKey) + if err == nil { + return fmt.Errorf("Unexpected peer disconnect") + } + + if err.Error() != "timed out waiting for peers to disconnect" { + return err + } + + return nil +} + +// creates a new HandshakeMsg +func newSwapHandshakeMsg(contractAddress common.Address, chainID uint64) *HandshakeMsg { + return &HandshakeMsg{ + ContractAddress: contractAddress, + ChainID: chainID, + } +} + +// creates the correct HandshakeMsg based on Swap instance +func correctSwapHandshakeMsg(swap *Swap) *HandshakeMsg { + return newSwapHandshakeMsg(swap.GetParams().ContractAddress, swap.chainID) +} + +// TestHandshake tests the correct handshake scenario +func TestHandshake(t *testing.T) { + // setup the protocolTester, which will allow protocol testing by sending messages + protocolTester, clean, err := newSwapTester(t) + defer clean() if err != nil { t.Fatal(err) } - // run the exchange: - // trigger a `EmitChequeMsg` - // expect HandshakeMsg on each node - err = protocolTester.TestExchanges(p2ptest.Exchange{ - Label: "TestHandshake", - Triggers: []p2ptest.Trigger{ - { - Code: 0, - Msg: &HandshakeMsg{ - ContractAddress: swap.GetParams().ContractAddress, - }, - Peer: creditor.ID(), - }, + err = protocolTester.testHandshake( + correctSwapHandshakeMsg(protocolTester.swap), + correctSwapHandshakeMsg(protocolTester.swap), + ) + if err != nil { + t.Fatal(err) + } +} + +// TestHandshakeInvalidChainID tests that a handshake with the wrong chain id is rejected +func TestHandshakeInvalidChainID(t *testing.T) { + // setup the protocolTester, which will allow protocol testing by sending messages + protocolTester, clean, err := newSwapTester(t) + defer clean() + if err != nil { + t.Fatal(err) + } + + err = protocolTester.testHandshake( + correctSwapHandshakeMsg(protocolTester.swap), + newSwapHandshakeMsg(protocolTester.swap.GetParams().ContractAddress, 1234), + &p2ptest.Disconnect{ + Peer: protocolTester.Nodes[0].ID(), + Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): %v", ErrDifferentChainID), }, - Expects: []p2ptest.Expect{ - { - Code: 0, - Msg: &HandshakeMsg{ - ContractAddress: swap.GetParams().ContractAddress, - }, - Peer: debitor.ID(), - }, + ) + if err != nil { + t.Fatal(err) + } +} + +// TestHandshakeEmptyContract tests that a handshake with an empty contract address is rejected +func TestHandshakeEmptyContract(t *testing.T) { + // setup the protocolTester, which will allow protocol testing by sending messages + protocolTester, clean, err := newSwapTester(t) + defer clean() + if err != nil { + t.Fatal(err) + } + + err = protocolTester.testHandshake( + correctSwapHandshakeMsg(protocolTester.swap), + newSwapHandshakeMsg(common.Address{}, 1234), + &p2ptest.Disconnect{ + Peer: protocolTester.Nodes[0].ID(), + Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): %v", ErrEmptyAddressInSignature), }, - }) + ) + if err != nil { + t.Fatal(err) + } +} - // there should be no error at this point +// TestHandshakeInvalidContract tests that a handshake with an address that's not a valid chequebook +func TestHandshakeInvalidContract(t *testing.T) { + // setup the protocolTester, which will allow protocol testing by sending messages + protocolTester, clean, err := newSwapTester(t) + defer clean() + if err != nil { + t.Fatal(err) + } + + err = protocolTester.testHandshake( + correctSwapHandshakeMsg(protocolTester.swap), + newSwapHandshakeMsg(ownerAddress, protocolTester.swap.chainID), + &p2ptest.Disconnect{ + Peer: protocolTester.Nodes[0].ID(), + Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): %v", contract.ErrNotDeployedByFactory), + }, + ) if err != nil { t.Fatal(err) } diff --git a/swap/simulations_test.go b/swap/simulations_test.go index 901daa783c..7c9eeb4c7f 100644 --- a/swap/simulations_test.go +++ b/swap/simulations_test.go @@ -231,7 +231,7 @@ func newSharedBackendSwaps(t *testing.T, nodeCount int) (*swapSimulationParams, if err != nil { t.Fatal(err) } - params.swaps[i] = newSwapInstance(stores[i], owner, testBackend, defParams, factory) + params.swaps[i] = newSwapInstance(stores[i], owner, testBackend, 10, defParams, factory) } params.backend = testBackend diff --git a/swap/swap.go b/swap/swap.go index fd79666965..3041f1a7c7 100644 --- a/swap/swap.go +++ b/swap/swap.go @@ -56,8 +56,9 @@ type Swap struct { store state.Store // store is needed in order to keep balances and cheques across sessions peers map[enode.ID]*Peer // map of all swap Peers peersLock sync.RWMutex // lock for peers map - backend contract.Backend // the backend (blockchain) used owner *Owner // contract access + backend contract.Backend // the backend (blockchain) used + chainID uint64 // id of the chain the backend is connected to params *Params // economic and operational parameters contract contract.Contract // reference to the smart contract chequebookFactory contract.SimpleSwapFactory // the chequebook factory used @@ -130,7 +131,7 @@ func swapRotatingFileHandler(logdir string) (log.Handler, error) { } // newSwapInstance is a swap constructor function without integrity checks -func newSwapInstance(stateStore state.Store, owner *Owner, backend contract.Backend, params *Params, chequebookFactory contract.SimpleSwapFactory) *Swap { +func newSwapInstance(stateStore state.Store, owner *Owner, backend contract.Backend, chainID uint64, params *Params, chequebookFactory contract.SimpleSwapFactory) *Swap { return &Swap{ store: stateStore, peers: make(map[enode.ID]*Peer), @@ -139,6 +140,7 @@ func newSwapInstance(stateStore state.Store, owner *Owner, backend contract.Back params: params, chequebookFactory: chequebookFactory, honeyPriceOracle: NewHoneyPriceOracle(), + chainID: chainID, } } @@ -193,6 +195,7 @@ func New(dbPath string, prvkey *ecdsa.PrivateKey, backendURL string, params *Par stateStore, owner, backend, + chainID.Uint64(), params, factory, ) diff --git a/swap/swap_test.go b/swap/swap_test.go index 2ca8ec6c2e..52c018a5ba 100644 --- a/swap/swap_test.go +++ b/swap/swap_test.go @@ -1371,7 +1371,7 @@ func newBaseTestSwapWithParams(t *testing.T, key *ecdsa.PrivateKey, params *Para if err != nil { t.Fatal(err) } - swap := newSwapInstance(stateStore, owner, backend, params, factory) + swap := newSwapInstance(stateStore, owner, backend, 10, params, factory) return swap, dir } diff --git a/swap/types.go b/swap/types.go index 246d4cf547..e6fcb24ed7 100644 --- a/swap/types.go +++ b/swap/types.go @@ -36,7 +36,8 @@ type Cheque struct { // HandshakeMsg is exchanged on peer handshake type HandshakeMsg struct { - ContractAddress common.Address + ChainID uint64 // chain id of the blockchain the peer is connected to + ContractAddress common.Address // chequebook contract address of the peer } // EmitChequeMsg is sent from the debitor to the creditor with the actual cheque