Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p2p: fix decoding of disconnect reason #204

Merged
merged 5 commits into from
Dec 15, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions p2p/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package p2p
import (
"bytes"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"math/big"
"sync/atomic"

"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
Expand Down Expand Up @@ -153,3 +155,78 @@ func (r *postrack) ReadByte() (byte, error) {
}
return b, err
}

// MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter.
func MsgPipe() (*MsgPipeRW, *MsgPipeRW) {
var (
c1, c2 = make(chan Msg), make(chan Msg)
closing = make(chan struct{})
closed = new(int32)
rw1 = &MsgPipeRW{c1, c2, closing, closed}
rw2 = &MsgPipeRW{c2, c1, closing, closed}
)
return rw1, rw2
}

// ErrPipeClosed is returned from pipe operations after the
// pipe has been closed.
var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe")

// MsgPipeRW is an endpoint of a MsgReadWriter pipe.
type MsgPipeRW struct {
w chan<- Msg
r <-chan Msg
closing chan struct{}
closed *int32
}

// WriteMsg sends a messsage on the pipe.
// It blocks until the receiver has consumed the message payload.
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
select {
case p.w <- msg:
if msg.Size > 0 {
// wait for payload read or discard
<-consumed
}
return nil
case <-p.closing:
}
}
return ErrPipeClosed
}

// EncodeMsg is a convenient shorthand for sending an RLP-encoded message.
func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error {
return p.WriteMsg(NewMsg(code, data...))
}

// ReadMsg returns a message sent on the other end of the pipe.
func (p *MsgPipeRW) ReadMsg() (Msg, error) {
if atomic.LoadInt32(p.closed) == 0 {
select {
case msg := <-p.r:
return msg, nil
case <-p.closing:
}
}
return Msg{}, ErrPipeClosed
}

// Close unblocks any pending ReadMsg and WriteMsg calls on both ends
// of the pipe. They will return ErrPipeClosed. Note that Close does
// not interrupt any reads from a message payload.
func (p *MsgPipeRW) Close() error {
if atomic.AddInt32(p.closed, 1) != 1 {
// someone else is already closing
atomic.StoreInt32(p.closed, 1) // avoid overflow
return nil
}
close(p.closing)
return nil
}
63 changes: 63 additions & 0 deletions p2p/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package p2p

import (
"bytes"
"fmt"
"io/ioutil"
"runtime"
"testing"
"time"

"github.com/ethereum/go-ethereum/ethutil"
)
Expand Down Expand Up @@ -68,3 +71,63 @@ func TestDecodeRealMsg(t *testing.T) {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
}
}

func ExampleMsgPipe() {
rw1, rw2 := MsgPipe()
go func() {
rw1.EncodeMsg(8, []byte{0, 0})
rw1.EncodeMsg(5, []byte{1, 1})
rw1.Close()
}()

for {
msg, err := rw2.ReadMsg()
if err != nil {
break
}
var data [1][]byte
msg.Decode(&data)
fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
}
// Output:
// msg: 8, 0000
// msg: 5, 0101
}

func TestMsgPipeUnblockWrite(t *testing.T) {
loop:
for i := 0; i < 100; i++ {
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := rw1.EncodeMsg(1); err == nil {
t.Error("EncodeMsg returned nil error")
} else if err != ErrPipeClosed {
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
}
close(done)
}()

// this call should ensure that EncodeMsg is waiting to
// deliver sometimes. if this isn't done, Close is likely to
// be executed before EncodeMsg starts and then we won't test
// all the cases.
runtime.Gosched()

rw2.Close()
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Errorf("write didn't unblock")
break loop
}
}
}

// This test should panic if concurrent close isn't implemented correctly.
func TestMsgPipeConcurrentClose(t *testing.T) {
rw1, _ := MsgPipe()
for i := 0; i < 10; i++ {
go rw1.Close()
}
}
17 changes: 12 additions & 5 deletions p2p/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, protoDone}
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr
proto.in <- msg
}
Expand Down Expand Up @@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
return msg, nil
}

// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
//
type eofSignal struct {
wrapped io.Reader
count int64
eof chan<- struct{}
}

// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.

func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
r.count -= int64(n)
if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
9 changes: 9 additions & 0 deletions p2p/peer_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,16 @@ func (d DiscReason) String() string {
return discReasonToString[d]
}

type discRequestedError DiscReason

func (err discRequestedError) Error() string {
return fmt.Sprintf("disconnect requested: %v", DiscReason(err))
}

func discReasonForError(err error) DiscReason {
if reason, ok := err.(discRequestedError); ok {
return DiscReason(reason)
}
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
Expand Down
56 changes: 56 additions & 0 deletions p2p/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/hex"
"io"
"io/ioutil"
"net"
"reflect"
Expand Down Expand Up @@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
}

func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)

// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}

// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 8 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}

// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}

// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}
5 changes: 2 additions & 3 deletions p2p/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,11 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
return newPeerError(errProtocolBreach, "extra handshake received")

case discMsg:
var reason DiscReason
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
bp.peer.Disconnect(reason)
return nil
return discRequestedError(reason[0])

case pingMsg:
return bp.rw.EncodeMsg(pongMsg)
Expand Down
58 changes: 58 additions & 0 deletions p2p/protocol_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package p2p

import (
"fmt"
"testing"
)

func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil)
peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar")
peer.pubkeyHook = func(*peerAddr) error { return nil }

rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := expectMsg(rw2, handshakeMsg); err != nil {
t.Error(err)
}
err := rw2.EncodeMsg(handshakeMsg,
baseProtocolVersion,
"",
[]interface{}{},
0,
make([]byte, 64),
)
if err != nil {
t.Error(err)
}
if err := expectMsg(rw2, getPeersMsg); err != nil {
t.Error(err)
}
if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil {
t.Error(err)
}
close(done)
}()

if err := runBaseProtocol(peer, rw1); err == nil {
t.Errorf("base protocol returned without error")
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
t.Errorf("base protocol returned wrong error: %v", err)
}
<-done
}

func expectMsg(r MsgReader, code uint64) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if err := msg.Discard(); err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
}
return nil
}