Skip to content

Commit

Permalink
Merge pull request #18 from libp2p/fix/reading-addrs
Browse files Browse the repository at this point in the history
fix reading from conn and addresses
  • Loading branch information
whyrusleeping authored Sep 5, 2017
2 parents 35370ce + cb1cdda commit e57234e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
53 changes: 47 additions & 6 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package websocket

import (
"io"
"net"
"time"

Expand All @@ -14,15 +15,54 @@ type Conn struct {
*ws.Conn
DefaultMessageType int
done func()
reader io.Reader
}

func (c *Conn) Read(b []byte) (n int, err error) {
_, r, err := c.Conn.NextReader()
func (c *Conn) Read(b []byte) (int, error) {
if c.reader == nil {
if err := c.prepNextReader(); err != nil {
return 0, err
}
}

for {
n, err := c.reader.Read(b)
switch err {
case io.EOF:
c.reader = nil

if n > 0 {
return n, nil
}

if err := c.prepNextReader(); err != nil {
return 0, err
}

// explicitly looping
default:
return n, err
}
}
}

func (c *Conn) prepNextReader() error {
t, r, err := c.Conn.NextReader()
if err != nil {
return 0, err
if wserr, ok := err.(*ws.CloseError); ok {
if wserr.Code == 1000 || wserr.Code == 1005 {
return io.EOF
}
}
return err
}

if t == ws.CloseMessage {
return io.EOF
}

return r.Read(b)
c.reader = r
return nil
}

func (c *Conn) Write(b []byte) (n int, err error) {
Expand All @@ -38,15 +78,16 @@ func (c *Conn) Close() error {
c.done()
}

c.Conn.WriteMessage(ws.CloseMessage, nil)
return c.Conn.Close()
}

func (c *Conn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
return NewAddr(c.Conn.LocalAddr().String())
}

func (c *Conn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
return NewAddr(c.Conn.RemoteAddr().String())
}

func (c *Conn) SetDeadline(t time.Time) error {
Expand Down
11 changes: 7 additions & 4 deletions p2p/transport/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package websocket

import (
"bytes"
"io/ioutil"
"testing"
"testing/iotest"

ma "github.com/multiformats/go-multiaddr"
)
Expand Down Expand Up @@ -40,13 +42,14 @@ func TestWebsocketListen(t *testing.T) {
}
defer c.Close()

buf := make([]byte, 32)
n, err := c.Read(buf)
obr := iotest.OneByteReader(c)

out, err := ioutil.ReadAll(obr)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(buf[:n], msg) {
t.Fatal("got wrong message", buf[:n], msg)
if !bytes.Equal(out, msg) {
t.Fatal("got wrong message", out, msg)
}
}

0 comments on commit e57234e

Please sign in to comment.