diff --git a/uacp/conn.go b/uacp/conn.go index f22930ce..09d4a196 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -9,7 +9,6 @@ import ( "io" "net" "sync/atomic" - "time" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" @@ -36,19 +35,19 @@ func nextid() uint32 { func Dial(ctx context.Context, endpoint string) (*Conn, error) { debug.Printf("Connect to %s", endpoint) - network, raddr, err := ResolveEndpoint(endpoint) + _, raddr, err := ResolveEndpoint(endpoint) if err != nil { return nil, err } var dialer net.Dialer - c, err := dialer.DialContext(ctx, network, raddr.String()) + c, err := dialer.DialContext(ctx, "tcp", raddr.String()) if err != nil { return nil, err } conn := &Conn{ - id: nextid(), - c: c, + TCPConn: c.(*net.TCPConn), + id: nextid(), ack: &Acknowledge{ ReceiveBufSize: DefaultReceiveBufSize, SendBufSize: DefaultSendBufSize, @@ -68,7 +67,7 @@ func Dial(ctx context.Context, endpoint string) (*Conn, error) { // Listener is a OPC UA Connection Protocol network listener. type Listener struct { - l net.Listener + l *net.TCPListener ack *Acknowledge endpoint string } @@ -94,7 +93,7 @@ func Listen(endpoint string, ack *Acknowledge) (*Listener, error) { if err != nil { return nil, err } - l, err := net.Listen(network, laddr.String()) + l, err := net.ListenTCP(network, laddr) if err != nil { return nil, err } @@ -110,11 +109,11 @@ func Listen(endpoint string, ack *Acknowledge) (*Listener, error) { // The first param ctx is to be passed to monitor(), which monitors and handles // incoming messages automatically in another goroutine. func (l *Listener) Accept(ctx context.Context) (*Conn, error) { - c, err := l.l.Accept() + c, err := l.l.AcceptTCP() if err != nil { return nil, err } - conn := &Conn{nextid(), c, l.ack} + conn := &Conn{c, nextid(), l.ack} if err := conn.srvhandshake(l.endpoint); err != nil { c.Close() return nil, err @@ -138,8 +137,8 @@ func (l *Listener) Endpoint() string { } type Conn struct { + *net.TCPConn id uint32 - c net.Conn ack *Acknowledge } @@ -165,35 +164,7 @@ func (c *Conn) MaxChunkCount() uint32 { func (c *Conn) Close() error { debug.Printf("conn %d: close", c.id) - return c.c.Close() -} - -func (c *Conn) Read(b []byte) (int, error) { - return c.c.Read(b) -} - -func (c *Conn) Write(b []byte) (int, error) { - return c.c.Write(b) -} - -func (c *Conn) SetDeadline(t time.Time) error { - return c.c.SetDeadline(t) -} - -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.c.SetReadDeadline(t) -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.c.SetWriteDeadline(t) -} - -func (c *Conn) LocalAddr() net.Addr { - return c.c.LocalAddr() -} - -func (c *Conn) RemoteAddr() net.Addr { - return c.c.RemoteAddr() + return c.TCPConn.Close() } func (c *Conn) handshake(endpoint string) error { @@ -290,12 +261,13 @@ func (c *Conn) srvhandshake(endpoint string) error { return errors.Errorf("uacp: invalid endpoint url %s", rhe.EndpointURL) } debug.Printf("conn %d: connecting to %s", c.id, rhe.ServerURI) - c.c.Close() - c, err := Dial(context.Background(), rhe.ServerURI) + c.Close() + var dialer net.Dialer + c2, err := dialer.DialContext(context.Background(), "tcp", rhe.ServerURI) if err != nil { return err } - c.c = c + c.TCPConn = c2.(*net.TCPConn) debug.Printf("conn %d: recv %#v", c.id, rhe) return nil @@ -322,7 +294,7 @@ const hdrlen = 8 func (c *Conn) Receive() ([]byte, error) { b := make([]byte, c.ack.ReceiveBufSize) - if _, err := io.ReadFull(c.c, b[:hdrlen]); err != nil { + if _, err := io.ReadFull(c, b[:hdrlen]); err != nil { // todo(fs): do not wrap this error since it hides io.EOF // todo(fs): use golang.org/x/xerrors return nil, err @@ -337,7 +309,7 @@ func (c *Conn) Receive() ([]byte, error) { return nil, errors.Errorf("uacp: message too large: %d > %d bytes", h.MessageSize, c.ack.ReceiveBufSize) } - if _, err := io.ReadFull(c.c, b[hdrlen:h.MessageSize]); err != nil { + if _, err := io.ReadFull(c, b[hdrlen:h.MessageSize]); err != nil { // todo(fs): do not wrap this error since it hides io.EOF // todo(fs): use golang.org/x/xerrors return nil, err @@ -381,7 +353,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { } b := append(hdr, body...) - if _, err := c.c.Write(b); err != nil { + if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } debug.Printf("conn %d: sent %s with %d bytes", c.id, typ, len(b))