diff --git a/connection.go b/connection.go index 5e9d691..725f2cf 100644 --- a/connection.go +++ b/connection.go @@ -32,7 +32,7 @@ type connConnection struct { } type chanConnection struct { - sendConn *net.UDPConn + sendConn net.PacketConn channel chan []byte addr *net.UDPAddr timeout time.Duration @@ -40,7 +40,7 @@ type chanConnection struct { } func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error { - _, err := c.sendConn.WriteToUDP(data, addr) + _, err := c.sendConn.WriteTo(data, addr) return err } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f5cbc82 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/pin/tftp + +go 1.13 + +require ( + github.com/stretchr/testify v1.4.0 + golang.org/x/net v0.0.0-20200202094626-16171245cfb2 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4d1a89f --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/server.go b/server.go index d2a4d62..89991fd 100644 --- a/server.go +++ b/server.go @@ -46,7 +46,7 @@ type Server struct { writeHandler func(filename string, wt io.WriterTo) error hook Hook backoff backoffFunc - conn *net.UDPConn + conn net.PacketConn conn6 *ipv6.PacketConn conn4 *ipv4.PacketConn quit chan chan struct{} @@ -188,7 +188,7 @@ func (s *Server) ListenAndServe(addr string) error { // useful for the case when you want to run server in separate goroutine // but still want to be able to handle any errors opening connection. // Serve returns when Shutdown is called or connection is closed. -func (s *Server) Serve(conn *net.UDPConn) error { +func (s *Server) Serve(conn net.PacketConn) error { defer conn.Close() laddr := conn.LocalAddr() host, _, err := net.SplitHostPort(laddr.String()) @@ -202,15 +202,18 @@ func (s *Server) Serve(conn *net.UDPConn) error { if addr == nil { return fmt.Errorf("Failed to determine IP class of listening address") } - if addr.To4() != nil { - s.conn4 = ipv4.NewPacketConn(conn) - if err := s.conn4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { - s.conn4 = nil - } - } else { - s.conn6 = ipv6.NewPacketConn(conn) - if err := s.conn6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { - s.conn6 = nil + + if conn, ok := conn.(*net.UDPConn); ok { + if addr.To4() != nil { + s.conn4 = ipv4.NewPacketConn(conn) + if err := s.conn4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { + s.conn4 = nil + } + } else { + s.conn6 = ipv6.NewPacketConn(conn) + if err := s.conn6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { + s.conn6 = nil + } } } @@ -290,11 +293,11 @@ func (s *Server) processRequest6() error { // Fallback if we had problems opening a ipv4/6 control channel func (s *Server) processRequest() error { buf := make([]byte, datagramLength) - cnt, srcAddr, err := s.conn.ReadFromUDP(buf) + cnt, srcAddr, err := s.conn.ReadFrom(buf) if err != nil { return fmt.Errorf("reading UDP: %v", err) } - return s.handlePacket(nil, srcAddr, buf, cnt, blockLength, nil) + return s.handlePacket(nil, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, nil) } // Shutdown make server stop listening for new requests, allows diff --git a/single_port.go b/single_port.go index c4eeae3..048d6e9 100644 --- a/single_port.go +++ b/single_port.go @@ -104,11 +104,11 @@ func (s *Server) getPacket(buf []byte) (int, net.IP, *net.UDPAddr, int, error) { } return cnt, localAddr, srcAddr.(*net.UDPAddr), maxSz, nil } else { - cnt, srcAddr, err := s.conn.ReadFromUDP(buf) + cnt, srcAddr, err := s.conn.ReadFrom(buf) if err != nil { return 0, nil, nil, 0, err } - return cnt, nil, srcAddr, blockLength, nil + return cnt, nil, srcAddr.(*net.UDPAddr), blockLength, nil } }