Skip to content

Commit

Permalink
swap net.Conn for net.TCPConn and embed
Browse files Browse the repository at this point in the history
  • Loading branch information
kung-foo committed Apr 1, 2020
1 parent 16e629f commit 6f756f3
Showing 1 changed file with 17 additions and 45 deletions.
62 changes: 17 additions & 45 deletions uacp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net"
"sync/atomic"
"time"

"github.com/gopcua/opcua/debug"
"github.com/gopcua/opcua/errors"
Expand All @@ -36,19 +35,19 @@ func nextid() uint32 {

func Dial(ctx context.Context, endpoint string) (*Conn, error) {
debug.Printf("Connect to %s", endpoint)
network, raddr, err := ResolveEndpoint(endpoint)
_, raddr, err := ResolveEndpoint(endpoint)
if err != nil {
return nil, err
}
var dialer net.Dialer
c, err := dialer.DialContext(ctx, network, raddr.String())
c, err := dialer.DialContext(ctx, "tcp", raddr.String())
if err != nil {
return nil, err
}

conn := &Conn{
id: nextid(),
c: c,
TCPConn: c.(*net.TCPConn),
id: nextid(),
ack: &Acknowledge{
ReceiveBufSize: DefaultReceiveBufSize,
SendBufSize: DefaultSendBufSize,
Expand All @@ -68,7 +67,7 @@ func Dial(ctx context.Context, endpoint string) (*Conn, error) {

// Listener is a OPC UA Connection Protocol network listener.
type Listener struct {
l net.Listener
l *net.TCPListener
ack *Acknowledge
endpoint string
}
Expand All @@ -94,7 +93,7 @@ func Listen(endpoint string, ack *Acknowledge) (*Listener, error) {
if err != nil {
return nil, err
}
l, err := net.Listen(network, laddr.String())
l, err := net.ListenTCP(network, laddr)
if err != nil {
return nil, err
}
Expand All @@ -110,11 +109,11 @@ func Listen(endpoint string, ack *Acknowledge) (*Listener, error) {
// The first param ctx is to be passed to monitor(), which monitors and handles
// incoming messages automatically in another goroutine.
func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
c, err := l.l.Accept()
c, err := l.l.AcceptTCP()
if err != nil {
return nil, err
}
conn := &Conn{nextid(), c, l.ack}
conn := &Conn{c, nextid(), l.ack}
if err := conn.srvhandshake(l.endpoint); err != nil {
c.Close()
return nil, err
Expand All @@ -138,8 +137,8 @@ func (l *Listener) Endpoint() string {
}

type Conn struct {
*net.TCPConn
id uint32
c net.Conn
ack *Acknowledge
}

Expand All @@ -165,35 +164,7 @@ func (c *Conn) MaxChunkCount() uint32 {

func (c *Conn) Close() error {
debug.Printf("conn %d: close", c.id)
return c.c.Close()
}

func (c *Conn) Read(b []byte) (int, error) {
return c.c.Read(b)
}

func (c *Conn) Write(b []byte) (int, error) {
return c.c.Write(b)
}

func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t)
}

func (c *Conn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t)
}

func (c *Conn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}

func (c *Conn) RemoteAddr() net.Addr {
return c.c.RemoteAddr()
return c.TCPConn.Close()
}

func (c *Conn) handshake(endpoint string) error {
Expand Down Expand Up @@ -290,12 +261,13 @@ func (c *Conn) srvhandshake(endpoint string) error {
return errors.Errorf("uacp: invalid endpoint url %s", rhe.EndpointURL)
}
debug.Printf("conn %d: connecting to %s", c.id, rhe.ServerURI)
c.c.Close()
c, err := Dial(context.Background(), rhe.ServerURI)
c.Close()
var dialer net.Dialer
c2, err := dialer.DialContext(context.Background(), "tcp", rhe.ServerURI)
if err != nil {
return err
}
c.c = c
c.TCPConn = c2.(*net.TCPConn)
debug.Printf("conn %d: recv %#v", c.id, rhe)
return nil

Expand All @@ -322,7 +294,7 @@ const hdrlen = 8
func (c *Conn) Receive() ([]byte, error) {
b := make([]byte, c.ack.ReceiveBufSize)

if _, err := io.ReadFull(c.c, b[:hdrlen]); err != nil {
if _, err := io.ReadFull(c, b[:hdrlen]); err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
Expand All @@ -337,7 +309,7 @@ func (c *Conn) Receive() ([]byte, error) {
return nil, errors.Errorf("uacp: message too large: %d > %d bytes", h.MessageSize, c.ack.ReceiveBufSize)
}

if _, err := io.ReadFull(c.c, b[hdrlen:h.MessageSize]); err != nil {
if _, err := io.ReadFull(c, b[hdrlen:h.MessageSize]); err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
Expand Down Expand Up @@ -381,7 +353,7 @@ func (c *Conn) Send(typ string, msg interface{}) error {
}

b := append(hdr, body...)
if _, err := c.c.Write(b); err != nil {
if _, err := c.Write(b); err != nil {
return errors.Errorf("write failed: %s", err)
}
debug.Printf("conn %d: sent %s with %d bytes", c.id, typ, len(b))
Expand Down

0 comments on commit 6f756f3

Please sign in to comment.