diff --git a/conn_test.go b/conn_test.go index ef234bc..f7d7253 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,11 +7,13 @@ import ( "fmt" "io/ioutil" mrand "math/rand" + "net" "time" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" + filter "github.com/libp2p/go-maddr-filter" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" @@ -62,12 +64,12 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv4", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -86,12 +88,12 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv6", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip6/::1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -110,12 +112,12 @@ var _ = Describe("Connection", func() { }) It("opens and accepts streams", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -139,11 +141,11 @@ var _ = Describe("Connection", func() { It("fails if the peer ID doesn't match", func() { thirdPartyID, _ := createPeer() - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) // dial, but expect the wrong peer ID _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) @@ -161,14 +163,47 @@ var _ = Describe("Connection", func() { Eventually(done).Should(BeClosed()) }) + It("filters addresses", func() { + filters := filter.NewFilters() + ipNet := net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + } + filters.AddFilter(ipNet, filter.ActionDeny) + testMA, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic") + Expect(err).ToNot(HaveOccurred()) + Expect(filters.AddrBlocked(testMA)).To(BeTrue()) + + serverTransport, err := NewTransport(serverKey, nil, filters) + Expect(err).ToNot(HaveOccurred()) + ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, nil, nil) + Expect(err).ToNot(HaveOccurred()) + + // make sure that connection attempts fails + quicConfig.HandshakeTimeout = 250 * time.Millisecond + _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + + // now allow the address and make sure the connection goes through + quicConfig.HandshakeTimeout = 2 * time.Second + filters.AddFilter(ipNet, filter.ActionAccept) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + Expect(err).ToNot(HaveOccurred()) + conn.Close() + }) + It("dials to two servers at the same time", func() { serverID2, serverKey2 := createPeer() - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - serverTransport2, err := NewTransport(serverKey2, nil) defer ln1.Close() + serverTransport2, err := NewTransport(serverKey2, nil, nil) Expect(err).ToNot(HaveOccurred()) ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic") defer ln2.Close() @@ -194,7 +229,7 @@ var _ = Describe("Connection", func() { } }() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) diff --git a/filtered_conn.go b/filtered_conn.go new file mode 100644 index 0000000..dc60bb0 --- /dev/null +++ b/filtered_conn.go @@ -0,0 +1,34 @@ +package libp2pquic + +import ( + "net" + + filter "github.com/libp2p/go-maddr-filter" +) + +type filteredConn struct { + net.PacketConn + + filters *filter.Filters +} + +func newFilteredConn(c net.PacketConn, filters *filter.Filters) net.PacketConn { + return &filteredConn{PacketConn: c, filters: filters} +} + +func (c *filteredConn) ReadFrom(b []byte) (n int, addr net.Addr, rerr error) { + for { + n, addr, rerr = c.PacketConn.ReadFrom(b) + // Short Header packet, see https://tools.ietf.org/html/draft-ietf-quic-invariants-07#section-4.2. + if n < 1 || b[0]&0x80 == 0 { + return + } + maddr, err := toQuicMultiaddr(addr) + if err != nil { + panic(err) + } + if !c.filters.AddrBlocked(maddr) { + return + } + } +} diff --git a/go.mod b/go.mod index e06c1e6..e75c856 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/ipfs/go-log v1.0.2 github.com/libp2p/go-libp2p-core v0.5.0 github.com/libp2p/go-libp2p-tls v0.1.3 + github.com/libp2p/go-maddr-filter v0.0.5 github.com/lucas-clemente/quic-go v0.15.2 github.com/multiformats/go-multiaddr v0.2.1 github.com/multiformats/go-multiaddr-fmt v0.1.0 diff --git a/go.sum b/go.sum index f848e72..bc82d3a 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/libp2p/go-libp2p-core v0.5.0 h1:FBQ1fpq2Fo/ClyjojVJ5AKXlKhvNc/B6U0O+7 github.com/libp2p/go-libp2p-core v0.5.0/go.mod h1:49XGI+kc38oGVwqSBhDEwytaAxgZasHhFfQKibzTls0= github.com/libp2p/go-libp2p-tls v0.1.3 h1:twKMhMu44jQO+HgQK9X8NHO5HkeJu2QbhLzLJpa8oNM= github.com/libp2p/go-libp2p-tls v0.1.3/go.mod h1:wZfuewxOndz5RTnCAxFliGjvYSDA40sKitV4c50uI1M= +github.com/libp2p/go-maddr-filter v0.0.5 h1:CW3AgbMO6vUvT4kf87y4N+0P8KUl2aqLYhrGyDUbLSg= +github.com/libp2p/go-maddr-filter v0.0.5/go.mod h1:Jk+36PMfIqCJhAnaASRH83bdAvfDRp/w6ENFaC9bG+M= github.com/libp2p/go-openssl v0.0.4 h1:d27YZvLoTyMhIN4njrkr8zMDOM4lfpHIp6A+TK9fovg= github.com/libp2p/go-openssl v0.0.4/go.mod h1:unDrJpgy3oFr+rqXsarWifmJuNnJR4chtO1HmaZjggc= github.com/lucas-clemente/quic-go v0.15.2 h1:RgxRJ7rPde0Q/uXDeb3/UdblVvxrYGDAG9G9GO78LmI= @@ -166,6 +168,7 @@ github.com/mr-tron/base58 v1.1.3/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjW github.com/multiformats/go-base32 v0.0.3 h1:tw5+NhuwaOjJCC5Pp82QuXbrmLzWg7uxlMFp8Nq/kkI= github.com/multiformats/go-base32 v0.0.3/go.mod h1:pLiuGC8y0QR3Ue4Zug5UzK9LjgbkL8NSQj0zQ5Nz/AA= github.com/multiformats/go-base32 v0.0.3/go.mod h1:pLiuGC8y0QR3Ue4Zug5UzK9LjgbkL8NSQj0zQ5Nz/AA= +github.com/multiformats/go-multiaddr v0.0.1/go.mod h1:xKVEak1K9cS1VdmPZW3LSIb6lgmoS58qz/pzqmAxV44= github.com/multiformats/go-multiaddr v0.0.2/go.mod h1:xKVEak1K9cS1VdmPZW3LSIb6lgmoS58qz/pzqmAxV44= github.com/multiformats/go-multiaddr v0.1.1 h1:rVAztJYMhCQ7vEFr8FvxW3mS+HF2eY/oPbOMeS0ZDnE= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= diff --git a/libp2pquic_suite_test.go b/libp2pquic_suite_test.go index a2e1df4..d6ae151 100644 --- a/libp2pquic_suite_test.go +++ b/libp2pquic_suite_test.go @@ -5,12 +5,13 @@ import ( mrand "math/rand" "runtime/pprof" "strings" + "testing" "time" + "github.com/lucas-clemente/quic-go" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "testing" ) func TestLibp2pQuicTransport(t *testing.T) { @@ -22,8 +23,11 @@ var _ = BeforeSuite(func() { mrand.Seed(GinkgoRandomSeed()) }) -var garbageCollectIntervalOrig time.Duration -var maxUnusedDurationOrig time.Duration +var ( + garbageCollectIntervalOrig time.Duration + maxUnusedDurationOrig time.Duration + origQuicConfig *quic.Config +) func isGarbageCollectorRunning() bool { var b bytes.Buffer @@ -37,10 +41,13 @@ var _ = BeforeEach(func() { maxUnusedDurationOrig = maxUnusedDuration garbageCollectInterval = 50 * time.Millisecond maxUnusedDuration = 0 + origQuicConfig = quicConfig + quicConfig = quicConfig.Clone() }) var _ = AfterEach(func() { Eventually(isGarbageCollectorRunning).Should(BeFalse()) garbageCollectInterval = garbageCollectIntervalOrig maxUnusedDuration = maxUnusedDurationOrig + quicConfig = origQuicConfig }) diff --git a/listener_test.go b/listener_test.go index ac7716f..3388f1a 100644 --- a/listener_test.go +++ b/listener_test.go @@ -23,7 +23,7 @@ var _ = Describe("Listener", func() { Expect(err).ToNot(HaveOccurred()) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil) + t, err = NewTransport(key, nil, nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/reuse_base.go b/reuse_base.go index 347a9a6..053b777 100644 --- a/reuse_base.go +++ b/reuse_base.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + filter "github.com/libp2p/go-maddr-filter" ) // Constant. Defined as variables to simplify testing. @@ -20,7 +22,10 @@ type reuseConn struct { unusedSince time.Time } -func newReuseConn(conn net.PacketConn) *reuseConn { +func newReuseConn(conn net.PacketConn, filters *filter.Filters) *reuseConn { + if filters != nil { + conn = newFilteredConn(conn, filters) + } return &reuseConn{PacketConn: conn} } @@ -49,6 +54,8 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool { type reuseBase struct { mutex sync.Mutex + filters *filter.Filters + garbageCollectorRunning bool unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn @@ -56,8 +63,9 @@ type reuseBase struct { global map[int]*reuseConn } -func newReuseBase() reuseBase { +func newReuseBase(filters *filter.Filters) reuseBase { return reuseBase{ + filters: filters, unicast: make(map[string]map[int]*reuseConn), global: make(map[int]*reuseConn), } @@ -139,7 +147,7 @@ func (r *reuseBase) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) if err != nil { return nil, err } - rconn := newReuseConn(conn) + rconn := newReuseConn(conn, r.filters) r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn return rconn, nil } @@ -151,7 +159,7 @@ func (r *reuseBase) Listen(network string, laddr *net.UDPAddr) (*reuseConn, erro } localAddr := conn.LocalAddr().(*net.UDPAddr) - rconn := newReuseConn(conn) + rconn := newReuseConn(conn, r.filters) rconn.IncreaseCount() r.mutex.Lock() diff --git a/reuse_linux_test.go b/reuse_linux_test.go index 8bc401a..6fe77db 100644 --- a/reuse_linux_test.go +++ b/reuse_linux_test.go @@ -14,7 +14,7 @@ var _ = Describe("Reuse (on Linux)", func() { BeforeEach(func() { var err error - reuse, err = newReuse() + reuse, err = newReuse(nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/reuse_not_win.go b/reuse_not_win.go index fb36b83..57097a3 100644 --- a/reuse_not_win.go +++ b/reuse_not_win.go @@ -5,6 +5,8 @@ package libp2pquic import ( "net" + filter "github.com/libp2p/go-maddr-filter" + "github.com/vishvananda/netlink" ) @@ -14,7 +16,7 @@ type reuse struct { handle *netlink.Handle // Only set on Linux. nil on other systems. } -func newReuse() (*reuse, error) { +func newReuse(filters *filter.Filters) (*reuse, error) { handle, err := netlink.NewHandle(SupportedNlFamilies...) if err == netlink.ErrNotImplemented { handle = nil @@ -22,7 +24,7 @@ func newReuse() (*reuse, error) { return nil, err } return &reuse{ - reuseBase: newReuseBase(), + reuseBase: newReuseBase(filters), handle: handle, }, nil } diff --git a/reuse_test.go b/reuse_test.go index e773915..6a24d5d 100644 --- a/reuse_test.go +++ b/reuse_test.go @@ -37,7 +37,7 @@ var _ = Describe("Reuse", func() { BeforeEach(func() { var err error - reuse, err = newReuse() + reuse, err = newReuse(nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/reuse_win.go b/reuse_win.go index 0f57c8e..14ea1ba 100644 --- a/reuse_win.go +++ b/reuse_win.go @@ -2,14 +2,18 @@ package libp2pquic -import "net" +import ( + "net" + + filter "github.com/libp2p/go-maddr-filter" +) type reuse struct { reuseBase } -func newReuse() (*reuse, error) { - return &reuse{reuseBase: newReuseBase()}, nil +func newReuse(filters *filter.Filters) (*reuse, error) { + return &reuse{reuseBase: newReuseBase(filters)}, nil } func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { diff --git a/transport.go b/transport.go index 6a61eae..1754c79 100644 --- a/transport.go +++ b/transport.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p-core/pnet" tpt "github.com/libp2p/go-libp2p-core/transport" p2ptls "github.com/libp2p/go-libp2p-tls" + filter "github.com/libp2p/go-maddr-filter" quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -36,12 +37,12 @@ type connManager struct { reuseUDP6 *reuse } -func newConnManager() (*connManager, error) { - reuseUDP4, err := newReuse() +func newConnManager(filters *filter.Filters) (*connManager, error) { + reuseUDP4, err := newReuse(filters) if err != nil { return nil, err } - reuseUDP6, err := newReuse() + reuseUDP6, err := newReuse(filters) if err != nil { return nil, err } @@ -89,7 +90,7 @@ type transport struct { var _ tpt.Transport = &transport{} // NewTransport creates a new QUIC transport -func NewTransport(key ic.PrivKey, psk pnet.PSK) (tpt.Transport, error) { +func NewTransport(key ic.PrivKey, psk pnet.PSK, filters *filter.Filters) (tpt.Transport, error) { if len(psk) > 0 { log.Error("QUIC doesn't support private networks yet.") return nil, errors.New("QUIC doesn't support private networks yet") @@ -102,7 +103,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK) (tpt.Transport, error) { if err != nil { return nil, err } - connManager, err := newConnManager() + connManager, err := newConnManager(filters) if err != nil { return nil, err }