diff --git a/go.mod b/go.mod index d2cd7c2c56..a85fea3223 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/libp2p/go-libp2p go 1.16 require ( + github.com/benbjohnson/clock v1.1.0 github.com/btcsuite/btcd v0.22.0-beta // indirect github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect github.com/gogo/protobuf v1.3.2 @@ -46,6 +47,7 @@ require ( github.com/multiformats/go-varint v0.0.6 github.com/prometheus/common v0.30.0 // indirect github.com/prometheus/procfs v0.7.3 // indirect + github.com/raulk/go-watchdog v1.2.0 github.com/stretchr/testify v1.7.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 go.uber.org/atomic v1.9.0 // indirect diff --git a/go.sum b/go.sum index f3976fc924..fa2a955dd3 100644 --- a/go.sum +++ b/go.sum @@ -102,6 +102,7 @@ github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wX github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -109,16 +110,22 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw= +github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d h1:t5Wuyh53qYyg9eqn4BbnlIT+vmhyww0TatL+zT3uWgI= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.1.0 h1:kq/SbG2BCKLkDKkjQf5OWwKWUKj1lgs3lFI4PxnR5lg= +github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 
v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -132,12 +139,16 @@ github.com/dgraph-io/badger v1.6.2/go.mod h1:JW2yswe3V058sS0kZ2h/AXeDSqFjxnZcRrV github.com/dgraph-io/ristretto v0.0.2/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/elastic/gosigar v0.12.0 h1:AsdhYCJlTudhfOYQyFNgx+fIVTfrDO0V1ST0vHgiapU= +github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs= github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -173,6 +184,8 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= +github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -222,6 +235,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod 
h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -619,6 +633,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= +github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= +github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= @@ -682,6 +698,10 @@ github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/raulk/clock v1.1.0 h1:dpb29+UKMbLqiU/jqIJptgLR1nn23HLgMY0sTCDza5Y= +github.com/raulk/clock v1.1.0/go.mod h1:3MpVxdZ/ODBQDxbN+kzshf5OSZwPjtMDx6BBXBmOeY0= +github.com/raulk/go-watchdog v1.2.0 h1:konN75pw2BMmZ+AfuAm5rtFsWcJpKF3m02rKituuXNo= +github.com/raulk/go-watchdog v1.2.0/go.mod h1:lzSbAl5sh4rtI8tYHU01BWIDzgzqaQLj6RcA1i4mlqI= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -718,6 +738,7 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smola/gocompat v0.2.0/go.mod h1:1B0MlxbmoZNo3h8guHp8HztB3BSYR5itql9qtVc0ypY= @@ -756,6 +777,7 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1 github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod 
h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/whyrusleeping/go-logging v0.0.0-20170515211332-0457bb6b88fc/go.mod h1:bopw91TMyo8J3tvftk8xmU2kPmlrt4nScJQZU2hE5EM= @@ -939,6 +961,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180810173357-98c5dad5d1a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -978,6 +1001,7 @@ golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/p2p/net/connmgr/bench_test.go b/p2p/net/connmgr/bench_test.go new file mode 100644 index 0000000000..73e25e97c5 --- /dev/null +++ b/p2p/net/connmgr/bench_test.go @@ -0,0 +1,54 @@ +package connmgr + +import ( + "math/rand" + "sync" + "testing" + + "github.com/libp2p/go-libp2p-core/network" + + "github.com/stretchr/testify/require" +) + +func randomConns(tb testing.TB) (c [5000]network.Conn) { + for i := range c { + c[i] = randConn(tb, nil) + } + return c +} + +func BenchmarkLockContention(b *testing.B) { + conns := randomConns(b) + cm, err := NewConnManager(1000, 1000, WithGracePeriod(0)) + require.NoError(b, err) + not := cm.Notifee() + + kill := make(chan struct{}) + var wg sync.WaitGroup + + for i := 0; i < 16; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-kill: + return + default: + cm.TagPeer(conns[rand.Intn(len(conns))].RemotePeer(), "another-tag", 1) + } + } + }() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rc := conns[rand.Intn(len(conns))] + not.Connected(nil, rc) + cm.TagPeer(rc.RemotePeer(), "tag", 100) + cm.UntagPeer(rc.RemotePeer(), "tag") + not.Disconnected(nil, rc) + } + close(kill) + wg.Wait() +} diff --git a/p2p/net/connmgr/connmgr.go b/p2p/net/connmgr/connmgr.go new file mode 100644 index 0000000000..9ee587015c --- /dev/null +++ b/p2p/net/connmgr/connmgr.go @@ -0,0 +1,694 @@ +package connmgr + +import ( + "context" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + logging "github.com/ipfs/go-log/v2" + ma "github.com/multiformats/go-multiaddr" +) + +var 
log = logging.Logger("connmgr") + +// BasicConnMgr is a ConnManager that trims connections whenever the count exceeds the +// high watermark. New connections are given a grace period before they're subject +// to trimming. Trims are automatically run on demand, only if the time from the +// previous trim is higher than 10 seconds. Furthermore, trims can be explicitly +// requested through the public interface of this struct (see TrimOpenConns). +// +// See configuration parameters in NewConnManager. +type BasicConnMgr struct { + *decayer + + cfg *config + segments segments + + plk sync.RWMutex + protected map[peer.ID]map[string]struct{} + + // channel-based semaphore that enforces only a single trim is in progress + trimMutex sync.Mutex + connCount int32 + // to be accessed atomically. This is mimicking the implementation of a sync.Once. + // Take care of correct alignment when modifying this struct. + trimCount uint64 + + lastTrimMu sync.RWMutex + lastTrim time.Time + + refCount sync.WaitGroup + ctx context.Context + cancel func() + unregisterMemoryWatcher func() +} + +var ( + _ connmgr.ConnManager = (*BasicConnMgr)(nil) + _ connmgr.Decayer = (*BasicConnMgr)(nil) +) + +type segment struct { + sync.Mutex + peers map[peer.ID]*peerInfo +} + +type segments [256]*segment + +func (ss *segments) get(p peer.ID) *segment { + return ss[byte(p[len(p)-1])] +} + +func (ss *segments) countPeers() (count int) { + for _, seg := range ss { + seg.Lock() + count += len(seg.peers) + seg.Unlock() + } + return count +} + +func (s *segment) tagInfoFor(p peer.ID) *peerInfo { + pi, ok := s.peers[p] + if ok { + return pi + } + // create a temporary peer to buffer early tags before the Connected notification arrives. + pi = &peerInfo{ + id: p, + firstSeen: time.Now(), // this timestamp will be updated when the first Connected notification arrives. + temp: true, + tags: make(map[string]int), + decaying: make(map[*decayingTag]*connmgr.DecayingValue), + conns: make(map[network.Conn]time.Time), + } + s.peers[p] = pi + return pi +} + +// NewConnManager creates a new BasicConnMgr with the provided params: +// lo and hi are watermarks governing the number of connections that'll be maintained. +// When the peer count exceeds the 'high watermark', as many peers will be pruned (and +// their connections terminated) until 'low watermark' peers remain. +func NewConnManager(low, hi int, opts ...Option) (*BasicConnMgr, error) { + cfg := &config{ + highWater: hi, + lowWater: low, + gracePeriod: time.Minute, + silencePeriod: 10 * time.Second, + } + for _, o := range opts { + if err := o(cfg); err != nil { + return nil, err + } + } + + if cfg.decayer == nil { + // Set the default decayer config. + cfg.decayer = (&DecayerCfg{}).WithDefaults() + } + + cm := &BasicConnMgr{ + cfg: cfg, + protected: make(map[peer.ID]map[string]struct{}, 16), + segments: func() (ret segments) { + for i := range ret { + ret[i] = &segment{ + peers: make(map[peer.ID]*peerInfo), + } + } + return ret + }(), + } + cm.ctx, cm.cancel = context.WithCancel(context.Background()) + + if cfg.emergencyTrim { + // When we're running low on memory, immediately trigger a trim. + cm.unregisterMemoryWatcher = registerWatchdog(cm.memoryEmergency) + } + + decay, _ := NewDecayer(cfg.decayer, cm) + cm.decayer = decay + + cm.refCount.Add(1) + go cm.background() + return cm, nil +} + +// memoryEmergency is run when we run low on memory. +// Close connections until we right the low watermark. +// We don't pay attention to the silence period or the grace period. 
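+// Connections are closed in the order produced by SortByValueAndStreams: temporary entries and
+// low-valued peers first, preferring inbound and stream-heavy connections among equal values.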
+// We try to not kill protected connections, but if that turns out to be necessary, not connection is safe! +func (cm *BasicConnMgr) memoryEmergency() { + connCount := int(atomic.LoadInt32(&cm.connCount)) + target := connCount - cm.cfg.lowWater + if target < 0 { + log.Warnw("Low on memory, but we only have a few connections", "num", connCount, "low watermark", cm.cfg.lowWater) + return + } else { + log.Warnf("Low on memory. Closing %d connections.", target) + } + + cm.trimMutex.Lock() + defer atomic.AddUint64(&cm.trimCount, 1) + defer cm.trimMutex.Unlock() + + // Trim connections without paying attention to the silence period. + for _, c := range cm.getConnsToCloseEmergency(target) { + log.Infow("low on memory. closing conn", "peer", c.RemotePeer()) + c.Close() + } + + // finally, update the last trim time. + cm.lastTrimMu.Lock() + cm.lastTrim = time.Now() + cm.lastTrimMu.Unlock() +} + +func (cm *BasicConnMgr) Close() error { + cm.cancel() + if cm.unregisterMemoryWatcher != nil { + cm.unregisterMemoryWatcher() + } + if err := cm.decayer.Close(); err != nil { + return err + } + cm.refCount.Wait() + return nil +} + +func (cm *BasicConnMgr) Protect(id peer.ID, tag string) { + cm.plk.Lock() + defer cm.plk.Unlock() + + tags, ok := cm.protected[id] + if !ok { + tags = make(map[string]struct{}, 2) + cm.protected[id] = tags + } + tags[tag] = struct{}{} +} + +func (cm *BasicConnMgr) Unprotect(id peer.ID, tag string) (protected bool) { + cm.plk.Lock() + defer cm.plk.Unlock() + + tags, ok := cm.protected[id] + if !ok { + return false + } + if delete(tags, tag); len(tags) == 0 { + delete(cm.protected, id) + return false + } + return true +} + +func (cm *BasicConnMgr) IsProtected(id peer.ID, tag string) (protected bool) { + cm.plk.Lock() + defer cm.plk.Unlock() + + tags, ok := cm.protected[id] + if !ok { + return false + } + + if tag == "" { + return true + } + + _, protected = tags[tag] + return protected +} + +// peerInfo stores metadata for a given peer. +type peerInfo struct { + id peer.ID + tags map[string]int // value for each tag + decaying map[*decayingTag]*connmgr.DecayingValue // decaying tags + + value int // cached sum of all tag values + temp bool // this is a temporary entry holding early tags, and awaiting connections + + conns map[network.Conn]time.Time // start time of each connection + + firstSeen time.Time // timestamp when we began tracking this peer. +} + +type peerInfos []peerInfo + +func (p peerInfos) SortByValue() { + sort.Slice(p, func(i, j int) bool { + left, right := p[i], p[j] + // temporary peers are preferred for pruning. + if left.temp != right.temp { + return left.temp + } + // otherwise, compare by value. + return left.value < right.value + }) +} + +func (p peerInfos) SortByValueAndStreams() { + sort.Slice(p, func(i, j int) bool { + left, right := p[i], p[j] + // temporary peers are preferred for pruning. + if left.temp != right.temp { + return left.temp + } + // otherwise, compare by value. 
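+		// (any remaining ties are broken further below by connection direction and stream count.)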
+ if left.value != right.value { + return left.value < right.value + } + incomingAndStreams := func(m map[network.Conn]time.Time) (incoming bool, numStreams int) { + for c := range m { + stat := c.Stat() + if stat.Direction == network.DirInbound { + incoming = true + } + numStreams += stat.NumStreams + } + return + } + leftIncoming, leftStreams := incomingAndStreams(left.conns) + rightIncoming, rightStreams := incomingAndStreams(right.conns) + // incoming connections are preferred for pruning + if leftIncoming != rightIncoming { + return leftIncoming + } + // prune connections with a higher number of streams first + return rightStreams < leftStreams + }) +} + +// TrimOpenConns closes the connections of as many peers as needed to make the peer count +// equal the low watermark. Peers are sorted in ascending order based on their total value, +// pruning those peers with the lowest scores first, as long as they are not within their +// grace period. +// +// This function blocks until a trim is completed. If a trim is underway, a new +// one won't be started, and instead it'll wait until that one is completed before +// returning. +func (cm *BasicConnMgr) TrimOpenConns(_ context.Context) { + // TODO: error return value so we can cleanly signal we are aborting because: + // (a) there's another trim in progress, or (b) the silence period is in effect. + + cm.doTrim() +} + +func (cm *BasicConnMgr) background() { + defer cm.refCount.Done() + + interval := cm.cfg.gracePeriod / 2 + if cm.cfg.silencePeriod != 0 { + interval = cm.cfg.silencePeriod + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if atomic.LoadInt32(&cm.connCount) < int32(cm.cfg.highWater) { + // Below high water, skip. + continue + } + case <-cm.ctx.Done(): + return + } + cm.trim() + } +} + +func (cm *BasicConnMgr) doTrim() { + // This logic is mimicking the implementation of sync.Once in the standard library. + count := atomic.LoadUint64(&cm.trimCount) + cm.trimMutex.Lock() + defer cm.trimMutex.Unlock() + if count == atomic.LoadUint64(&cm.trimCount) { + cm.trim() + cm.lastTrimMu.Lock() + cm.lastTrim = time.Now() + cm.lastTrimMu.Unlock() + atomic.AddUint64(&cm.trimCount, 1) + } +} + +// trim starts the trim, if the last trim happened before the configured silence period. +func (cm *BasicConnMgr) trim() { + // do the actual trim. + for _, c := range cm.getConnsToClose() { + log.Infow("closing conn", "peer", c.RemotePeer()) + c.Close() + } +} + +func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn { + candidates := make(peerInfos, 0, cm.segments.countPeers()) + + cm.plk.RLock() + for _, s := range cm.segments { + s.Lock() + for id, inf := range s.peers { + if _, ok := cm.protected[id]; ok { + // skip over protected peer. + continue + } + candidates = append(candidates, *inf) + } + s.Unlock() + } + cm.plk.RUnlock() + + // Sort peers according to their value. + candidates.SortByValueAndStreams() + + selected := make([]network.Conn, 0, target+10) + for _, inf := range candidates { + if target <= 0 { + break + } + for c := range inf.conns { + selected = append(selected, c) + } + target -= len(inf.conns) + } + if len(selected) >= target { + // We found enough connections that were not protected. + return selected + } + + // We didn't find enough unprotected connections. + // We have no choice but to kill some protected connections. 
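+	// Rebuild the candidate list from scratch, this time including protected peers, and prune
+	// again in the same priority order until the target is met.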
+ candidates = candidates[:0] + cm.plk.RLock() + for _, s := range cm.segments { + s.Lock() + for _, inf := range s.peers { + candidates = append(candidates, *inf) + } + s.Unlock() + } + cm.plk.RUnlock() + + candidates.SortByValueAndStreams() + for _, inf := range candidates { + if target <= 0 { + break + } + for c := range inf.conns { + selected = append(selected, c) + } + target -= len(inf.conns) + } + return selected +} + +// getConnsToClose runs the heuristics described in TrimOpenConns and returns the +// connections to close. +func (cm *BasicConnMgr) getConnsToClose() []network.Conn { + if cm.cfg.lowWater == 0 || cm.cfg.highWater == 0 { + // disabled + return nil + } + + if int(atomic.LoadInt32(&cm.connCount)) <= cm.cfg.lowWater { + log.Info("open connection count below limit") + return nil + } + + candidates := make(peerInfos, 0, cm.segments.countPeers()) + var ncandidates int + gracePeriodStart := time.Now().Add(-cm.cfg.gracePeriod) + + cm.plk.RLock() + for _, s := range cm.segments { + s.Lock() + for id, inf := range s.peers { + if _, ok := cm.protected[id]; ok { + // skip over protected peer. + continue + } + if inf.firstSeen.After(gracePeriodStart) { + // skip peers in the grace period. + continue + } + // note that we're copying the entry here, + // but since inf.conns is a map, it will still point to the original object + candidates = append(candidates, *inf) + ncandidates += len(inf.conns) + } + s.Unlock() + } + cm.plk.RUnlock() + + if ncandidates < cm.cfg.lowWater { + log.Info("open connection count above limit but too many are in the grace period") + // We have too many connections but fewer than lowWater + // connections out of the grace period. + // + // If we trimmed now, we'd kill potentially useful connections. + return nil + } + + // Sort peers according to their value. + candidates.SortByValue() + + target := ncandidates - cm.cfg.lowWater + + // slightly overallocate because we may have more than one conns per peer + selected := make([]network.Conn, 0, target+10) + + for _, inf := range candidates { + if target <= 0 { + break + } + + // lock this to protect from concurrent modifications from connect/disconnect events + s := cm.segments.get(inf.id) + s.Lock() + if len(inf.conns) == 0 && inf.temp { + // handle temporary entries for early tags -- this entry has gone past the grace period + // and still holds no connections, so prune it. + delete(s.peers, inf.id) + } else { + for c := range inf.conns { + selected = append(selected, c) + } + target -= len(inf.conns) + } + s.Unlock() + } + + return selected +} + +// GetTagInfo is called to fetch the tag information associated with a given +// peer, nil is returned if p refers to an unknown peer. +func (cm *BasicConnMgr) GetTagInfo(p peer.ID) *connmgr.TagInfo { + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + pi, ok := s.peers[p] + if !ok { + return nil + } + + out := &connmgr.TagInfo{ + FirstSeen: pi.firstSeen, + Value: pi.value, + Tags: make(map[string]int), + Conns: make(map[string]time.Time), + } + + for t, v := range pi.tags { + out.Tags[t] = v + } + for t, v := range pi.decaying { + out.Tags[t.name] = v.Value + } + for c, t := range pi.conns { + out.Conns[c.RemoteMultiaddr().String()] = t + } + + return out +} + +// TagPeer is called to associate a string and integer with a given peer. +func (cm *BasicConnMgr) TagPeer(p peer.ID, tag string, val int) { + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + pi := s.tagInfoFor(p) + + // Update the total value of the peer. 
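+	// Only the difference between the new and the previous value of this tag is applied to the cached total.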
+ pi.value += val - pi.tags[tag] + pi.tags[tag] = val +} + +// UntagPeer is called to disassociate a string and integer from a given peer. +func (cm *BasicConnMgr) UntagPeer(p peer.ID, tag string) { + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + pi, ok := s.peers[p] + if !ok { + log.Info("tried to remove tag from untracked peer: ", p) + return + } + + // Update the total value of the peer. + pi.value -= pi.tags[tag] + delete(pi.tags, tag) +} + +// UpsertTag is called to insert/update a peer tag +func (cm *BasicConnMgr) UpsertTag(p peer.ID, tag string, upsert func(int) int) { + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + pi := s.tagInfoFor(p) + + oldval := pi.tags[tag] + newval := upsert(oldval) + pi.value += newval - oldval + pi.tags[tag] = newval +} + +// CMInfo holds the configuration for BasicConnMgr, as well as status data. +type CMInfo struct { + // The low watermark, as described in NewConnManager. + LowWater int + + // The high watermark, as described in NewConnManager. + HighWater int + + // The timestamp when the last trim was triggered. + LastTrim time.Time + + // The configured grace period, as described in NewConnManager. + GracePeriod time.Duration + + // The current connection count. + ConnCount int +} + +// GetInfo returns the configuration and status data for this connection manager. +func (cm *BasicConnMgr) GetInfo() CMInfo { + cm.lastTrimMu.RLock() + lastTrim := cm.lastTrim + cm.lastTrimMu.RUnlock() + + return CMInfo{ + HighWater: cm.cfg.highWater, + LowWater: cm.cfg.lowWater, + LastTrim: lastTrim, + GracePeriod: cm.cfg.gracePeriod, + ConnCount: int(atomic.LoadInt32(&cm.connCount)), + } +} + +// Notifee returns a sink through which Notifiers can inform the BasicConnMgr when +// events occur. Currently, the notifee only reacts upon connection events +// {Connected, Disconnected}. +func (cm *BasicConnMgr) Notifee() network.Notifiee { + return (*cmNotifee)(cm) +} + +type cmNotifee BasicConnMgr + +func (nn *cmNotifee) cm() *BasicConnMgr { + return (*BasicConnMgr)(nn) +} + +// Connected is called by notifiers to inform that a new connection has been established. +// The notifee updates the BasicConnMgr to start tracking the connection. If the new connection +// count exceeds the high watermark, a trim may be triggered. +func (nn *cmNotifee) Connected(n network.Network, c network.Conn) { + cm := nn.cm() + + p := c.RemotePeer() + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + id := c.RemotePeer() + pinfo, ok := s.peers[id] + if !ok { + pinfo = &peerInfo{ + id: id, + firstSeen: time.Now(), + tags: make(map[string]int), + decaying: make(map[*decayingTag]*connmgr.DecayingValue), + conns: make(map[network.Conn]time.Time), + } + s.peers[id] = pinfo + } else if pinfo.temp { + // we had created a temporary entry for this peer to buffer early tags before the + // Connected notification arrived: flip the temporary flag, and update the firstSeen + // timestamp to the real one. + pinfo.temp = false + pinfo.firstSeen = time.Now() + } + + _, ok = pinfo.conns[c] + if ok { + log.Error("received connected notification for conn we are already tracking: ", p) + return + } + + pinfo.conns[c] = time.Now() + atomic.AddInt32(&cm.connCount, 1) +} + +// Disconnected is called by notifiers to inform that an existing connection has been closed or terminated. +// The notifee updates the BasicConnMgr accordingly to stop tracking the connection, and performs housekeeping. 
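+// If the closed connection was the peer's last one, the peer's entry is removed entirely.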
+func (nn *cmNotifee) Disconnected(n network.Network, c network.Conn) { + cm := nn.cm() + + p := c.RemotePeer() + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + cinf, ok := s.peers[p] + if !ok { + log.Error("received disconnected notification for peer we are not tracking: ", p) + return + } + + _, ok = cinf.conns[c] + if !ok { + log.Error("received disconnected notification for conn we are not tracking: ", p) + return + } + + delete(cinf.conns, c) + if len(cinf.conns) == 0 { + delete(s.peers, p) + } + atomic.AddInt32(&cm.connCount, -1) +} + +// Listen is no-op in this implementation. +func (nn *cmNotifee) Listen(n network.Network, addr ma.Multiaddr) {} + +// ListenClose is no-op in this implementation. +func (nn *cmNotifee) ListenClose(n network.Network, addr ma.Multiaddr) {} + +// OpenedStream is no-op in this implementation. +func (nn *cmNotifee) OpenedStream(network.Network, network.Stream) {} + +// ClosedStream is no-op in this implementation. +func (nn *cmNotifee) ClosedStream(network.Network, network.Stream) {} diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go new file mode 100644 index 0000000000..390c8599d5 --- /dev/null +++ b/p2p/net/connmgr/connmgr_test.go @@ -0,0 +1,857 @@ +package connmgr + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/crypto" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + tu "github.com/libp2p/go-libp2p-core/test" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +type tconn struct { + network.Conn + + peer peer.ID + closed uint32 // to be used atomically. Closed if 1 + disconnectNotify func(net network.Network, conn network.Conn) +} + +func (c *tconn) Close() error { + atomic.StoreUint32(&c.closed, 1) + if c.disconnectNotify != nil { + c.disconnectNotify(nil, c) + } + return nil +} + +func (c *tconn) isClosed() bool { + return atomic.LoadUint32(&c.closed) == 1 +} + +func (c *tconn) RemotePeer() peer.ID { + return c.peer +} + +func (c *tconn) RemoteMultiaddr() ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234") + if err != nil { + panic("cannot create multiaddr") + } + return addr +} + +func randConn(t testing.TB, discNotify func(network.Network, network.Conn)) network.Conn { + pid := tu.RandPeerIDFatal(t) + return &tconn{peer: pid, disconnectNotify: discNotify} +} + +// Make sure multiple trim calls block. +func TestTrimBlocks(t *testing.T) { + cm, err := NewConnManager(200, 300, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + + cm.lastTrimMu.RLock() + + doneCh := make(chan struct{}, 2) + go func() { + cm.TrimOpenConns(context.Background()) + doneCh <- struct{}{} + }() + go func() { + cm.TrimOpenConns(context.Background()) + doneCh <- struct{}{} + }() + time.Sleep(time.Millisecond) + select { + case <-doneCh: + cm.lastTrimMu.RUnlock() + t.Fatal("expected trim to block") + default: + cm.lastTrimMu.RUnlock() + } + <-doneCh + <-doneCh +} + +// Make sure trim returns when closed. +func TestTrimClosed(t *testing.T) { + cm, err := NewConnManager(200, 300, WithGracePeriod(0)) + require.NoError(t, err) + require.NoError(t, cm.Close()) + cm.TrimOpenConns(context.Background()) +} + +// Make sure joining an existing trim works. 
+func TestTrimJoin(t *testing.T) { + cm, err := NewConnManager(200, 300, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + + cm.lastTrimMu.RLock() + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + cm.TrimOpenConns(context.Background()) + }() + time.Sleep(time.Millisecond) + go func() { + defer wg.Done() + cm.TrimOpenConns(context.Background()) + }() + go func() { + defer wg.Done() + cm.TrimOpenConns(context.Background()) + }() + time.Sleep(time.Millisecond) + cm.lastTrimMu.RUnlock() + wg.Wait() +} + +func TestConnTrimming(t *testing.T) { + cm, err := NewConnManager(200, 300, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + + var conns []network.Conn + for i := 0; i < 300; i++ { + rc := randConn(t, nil) + conns = append(conns, rc) + not.Connected(nil, rc) + } + + for _, c := range conns { + if c.(*tconn).isClosed() { + t.Fatal("nothing should be closed yet") + } + } + + for i := 0; i < 100; i++ { + cm.TagPeer(conns[i].RemotePeer(), "foo", 10) + } + + cm.TagPeer(conns[299].RemotePeer(), "badfoo", -5) + + cm.TrimOpenConns(context.Background()) + + for i := 0; i < 100; i++ { + c := conns[i] + if c.(*tconn).isClosed() { + t.Fatal("these shouldnt be closed") + } + } + + if !conns[299].(*tconn).isClosed() { + t.Fatal("conn with bad tag should have gotten closed") + } +} + +func TestConnsToClose(t *testing.T) { + addConns := func(cm *BasicConnMgr, n int) { + not := cm.Notifee() + for i := 0; i < n; i++ { + conn := randConn(t, nil) + not.Connected(nil, conn) + } + } + + t.Run("below hi limit", func(t *testing.T) { + cm, err := NewConnManager(0, 10, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + addConns(cm, 5) + require.Empty(t, cm.getConnsToClose()) + }) + + t.Run("below low limit", func(t *testing.T) { + cm, err := NewConnManager(10, 0, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + addConns(cm, 5) + require.Empty(t, cm.getConnsToClose()) + }) + + t.Run("below low and hi limit", func(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + addConns(cm, 1) + require.Empty(t, cm.getConnsToClose()) + }) + + t.Run("within silence period", func(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute)) + require.NoError(t, err) + defer cm.Close() + addConns(cm, 1) + require.Empty(t, cm.getConnsToClose()) + }) +} + +func TestGetTagInfo(t *testing.T) { + start := time.Now() + cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute)) + require.NoError(t, err) + defer cm.Close() + + not := cm.Notifee() + conn := randConn(t, nil) + not.Connected(nil, conn) + end := time.Now() + + other := tu.RandPeerIDFatal(t) + tag := cm.GetTagInfo(other) + if tag != nil { + t.Fatal("expected no tag") + } + + tag = cm.GetTagInfo(conn.RemotePeer()) + if tag == nil { + t.Fatal("expected tag") + } + if tag.FirstSeen.Before(start) || tag.FirstSeen.After(end) { + t.Fatal("expected first seen time") + } + if tag.Value != 0 { + t.Fatal("expected zero value") + } + if len(tag.Tags) != 0 { + t.Fatal("expected no tags") + } + if len(tag.Conns) != 1 { + t.Fatal("expected one connection") + } + for s, tm := range tag.Conns { + if s != conn.RemoteMultiaddr().String() { + t.Fatal("unexpected multiaddr") + } + if tm.Before(start) || tm.After(end) { + t.Fatal("unexpected connection time") + } + } + + cm.TagPeer(conn.RemotePeer(), "tag", 5) + tag = cm.GetTagInfo(conn.RemotePeer()) + if tag == nil { + t.Fatal("expected tag") + } + if 
tag.FirstSeen.Before(start) || tag.FirstSeen.After(end) { + t.Fatal("expected first seen time") + } + if tag.Value != 5 { + t.Fatal("expected five value") + } + if len(tag.Tags) != 1 { + t.Fatal("expected no tags") + } + for tString, v := range tag.Tags { + if tString != "tag" || v != 5 { + t.Fatal("expected tag value") + } + } + if len(tag.Conns) != 1 { + t.Fatal("expected one connection") + } + for s, tm := range tag.Conns { + if s != conn.RemoteMultiaddr().String() { + t.Fatal("unexpected multiaddr") + } + if tm.Before(start) || tm.After(end) { + t.Fatal("unexpected connection time") + } + } +} + +func TestTagPeerNonExistant(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute)) + require.NoError(t, err) + defer cm.Close() + + id := tu.RandPeerIDFatal(t) + cm.TagPeer(id, "test", 1) + + if !cm.segments.get(id).peers[id].temp { + t.Fatal("expected 1 temporary entry") + } +} + +func TestUntagPeer(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + + conn := randConn(t, nil) + not.Connected(nil, conn) + rp := conn.RemotePeer() + cm.TagPeer(rp, "tag", 5) + cm.TagPeer(rp, "tag two", 5) + + id := tu.RandPeerIDFatal(t) + cm.UntagPeer(id, "test") + if len(cm.segments.get(rp).peers[rp].tags) != 2 { + t.Fatal("expected tags to be uneffected") + } + + cm.UntagPeer(conn.RemotePeer(), "test") + if len(cm.segments.get(rp).peers[rp].tags) != 2 { + t.Fatal("expected tags to be uneffected") + } + + cm.UntagPeer(conn.RemotePeer(), "tag") + if len(cm.segments.get(rp).peers[rp].tags) != 1 { + t.Fatal("expected tag to be removed") + } + if cm.segments.get(rp).peers[rp].value != 5 { + t.Fatal("expected aggreagte tag value to be 5") + } +} + +func TestGetInfo(t *testing.T) { + start := time.Now() + const gp = 10 * time.Minute + cm, err := NewConnManager(1, 5, WithGracePeriod(gp)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + conn := randConn(t, nil) + not.Connected(nil, conn) + cm.TrimOpenConns(context.Background()) + end := time.Now() + + info := cm.GetInfo() + if info.HighWater != 5 { + t.Fatal("expected highwater to be 5") + } + if info.LowWater != 1 { + t.Fatal("expected highwater to be 1") + } + if info.LastTrim.Before(start) || info.LastTrim.After(end) { + t.Fatal("unexpected last trim time") + } + if info.GracePeriod != gp { + t.Fatal("unexpected grace period") + } + if info.ConnCount != 1 { + t.Fatal("unexpected number of connections") + } +} + +func TestDoubleConnection(t *testing.T) { + const gp = 10 * time.Minute + cm, err := NewConnManager(1, 5, WithGracePeriod(gp)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + conn := randConn(t, nil) + not.Connected(nil, conn) + cm.TagPeer(conn.RemotePeer(), "foo", 10) + not.Connected(nil, conn) + if cm.connCount != 1 { + t.Fatal("unexpected number of connections") + } + if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 { + t.Fatal("unexpected peer value") + } +} + +func TestDisconnected(t *testing.T) { + const gp = 10 * time.Minute + cm, err := NewConnManager(1, 5, WithGracePeriod(gp)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + conn := randConn(t, nil) + not.Connected(nil, conn) + cm.TagPeer(conn.RemotePeer(), "foo", 10) + + not.Disconnected(nil, randConn(t, nil)) + if cm.connCount != 1 { + t.Fatal("unexpected number of connections") + } + if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 { + t.Fatal("unexpected peer 
value") + } + + not.Disconnected(nil, &tconn{peer: conn.RemotePeer()}) + if cm.connCount != 1 { + t.Fatal("unexpected number of connections") + } + if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 { + t.Fatal("unexpected peer value") + } + + not.Disconnected(nil, conn) + if cm.connCount != 0 { + t.Fatal("unexpected number of connections") + } + if cm.segments.countPeers() != 0 { + t.Fatal("unexpected number of peers") + } +} + +func TestGracePeriod(t *testing.T) { + const gp = 100 * time.Millisecond + cm, err := NewConnManager(10, 20, WithGracePeriod(gp), WithSilencePeriod(time.Hour)) + require.NoError(t, err) + defer cm.Close() + + not := cm.Notifee() + + var conns []network.Conn + + // Add a connection and wait the grace period. + { + rc := randConn(t, not.Disconnected) + conns = append(conns, rc) + not.Connected(nil, rc) + + time.Sleep(2 * gp) + + if rc.(*tconn).isClosed() { + t.Fatal("expected conn to remain open") + } + } + + // quickly add 30 connections (sending us above the high watermark) + for i := 0; i < 30; i++ { + rc := randConn(t, not.Disconnected) + conns = append(conns, rc) + not.Connected(nil, rc) + } + + cm.TrimOpenConns(context.Background()) + + for _, c := range conns { + if c.(*tconn).isClosed() { + t.Fatal("expected no conns to be closed") + } + } + + time.Sleep(200 * time.Millisecond) + + cm.TrimOpenConns(context.Background()) + + closed := 0 + for _, c := range conns { + if c.(*tconn).isClosed() { + closed++ + } + } + + if closed != 21 { + t.Fatal("expected to have closed 21 connections") + } +} + +// see https://github.com/libp2p/go-libp2p-connmgr/issues/23 +func TestQuickBurstRespectsSilencePeriod(t *testing.T) { + cm, err := NewConnManager(10, 20, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + + var conns []network.Conn + + // quickly produce 30 connections (sending us above the high watermark) + for i := 0; i < 30; i++ { + rc := randConn(t, not.Disconnected) + conns = append(conns, rc) + not.Connected(nil, rc) + } + + // wait for a few seconds + time.Sleep(time.Second * 3) + + // only the first trim is allowed in; make sure we close at most 20 connections, not all of them. + var closed int + for _, c := range conns { + if c.(*tconn).isClosed() { + closed++ + } + } + if closed > 20 { + t.Fatalf("should have closed at most 20 connections, closed: %d", closed) + } + if total := closed + int(cm.connCount); total != 30 { + t.Fatalf("expected closed connections + open conn count to equal 30, value: %d", total) + } +} + +func TestPeerProtectionSingleTag(t *testing.T) { + cm, err := NewConnManager(19, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + + var conns []network.Conn + addConn := func(value int) { + rc := randConn(t, not.Disconnected) + conns = append(conns, rc) + not.Connected(nil, rc) + cm.TagPeer(rc.RemotePeer(), "test", value) + } + + // produce 20 connections with unique peers. + for i := 0; i < 20; i++ { + addConn(20) + } + + // protect the first 5 peers. + var protected []network.Conn + for _, c := range conns[0:5] { + cm.Protect(c.RemotePeer(), "global") + protected = append(protected, c) + // tag them negatively to make them preferred for pruning. 
+ cm.TagPeer(c.RemotePeer(), "test", -100) + } + + // add 1 more conn, this shouldn't send us over the limit as protected conns don't count + addConn(20) + + time.Sleep(100 * time.Millisecond) + cm.TrimOpenConns(context.Background()) + + for _, c := range conns { + if c.(*tconn).isClosed() { + t.Error("connection was closed by connection manager") + } + } + + // add 5 more connection, sending the connection manager overboard. + for i := 0; i < 5; i++ { + addConn(20) + } + + cm.TrimOpenConns(context.Background()) + + for _, c := range protected { + if c.(*tconn).isClosed() { + t.Error("protected connection was closed by connection manager") + } + } + + closed := 0 + for _, c := range conns { + if c.(*tconn).isClosed() { + closed++ + } + } + if closed != 2 { + t.Errorf("expected 2 connection to be closed, found %d", closed) + } + + // unprotect the first peer. + cm.Unprotect(protected[0].RemotePeer(), "global") + + // add 2 more connections, sending the connection manager overboard again. + for i := 0; i < 2; i++ { + addConn(20) + } + + cm.TrimOpenConns(context.Background()) + + if !protected[0].(*tconn).isClosed() { + t.Error("unprotected connection was kept open by connection manager") + } + for _, c := range protected[1:] { + if c.(*tconn).isClosed() { + t.Error("protected connection was closed by connection manager") + } + } +} + +func TestPeerProtectionMultipleTags(t *testing.T) { + cm, err := NewConnManager(19, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + + // produce 20 connections with unique peers. + var conns []network.Conn + for i := 0; i < 20; i++ { + rc := randConn(t, not.Disconnected) + conns = append(conns, rc) + not.Connected(nil, rc) + cm.TagPeer(rc.RemotePeer(), "test", 20) + } + + // protect the first 5 peers under two tags. + var protected []network.Conn + for _, c := range conns[0:5] { + cm.Protect(c.RemotePeer(), "tag1") + cm.Protect(c.RemotePeer(), "tag2") + protected = append(protected, c) + // tag them negatively to make them preferred for pruning. + cm.TagPeer(c.RemotePeer(), "test", -100) + } + + // add one more connection, sending the connection manager overboard. + not.Connected(nil, randConn(t, not.Disconnected)) + + cm.TrimOpenConns(context.Background()) + + for _, c := range protected { + if c.(*tconn).isClosed() { + t.Error("protected connection was closed by connection manager") + } + } + + // remove the protection from one tag. + for _, c := range protected { + if !cm.Unprotect(c.RemotePeer(), "tag1") { + t.Error("peer should still be protected") + } + } + + // add 2 more connections, sending the connection manager overboard again. + for i := 0; i < 2; i++ { + rc := randConn(t, not.Disconnected) + not.Connected(nil, rc) + cm.TagPeer(rc.RemotePeer(), "test", 20) + } + + cm.TrimOpenConns(context.Background()) + + // connections should still remain open, as they were protected. + for _, c := range protected[0:] { + if c.(*tconn).isClosed() { + t.Error("protected connection was closed by connection manager") + } + } + + // unprotect the first peer entirely. + cm.Unprotect(protected[0].RemotePeer(), "tag2") + + // add 2 more connections, sending the connection manager overboard again. 
+ for i := 0; i < 2; i++ { + rc := randConn(t, not.Disconnected) + not.Connected(nil, rc) + cm.TagPeer(rc.RemotePeer(), "test", 20) + } + + cm.TrimOpenConns(context.Background()) + + if !protected[0].(*tconn).isClosed() { + t.Error("unprotected connection was kept open by connection manager") + } + for _, c := range protected[1:] { + if c.(*tconn).isClosed() { + t.Error("protected connection was closed by connection manager") + } + } + +} + +func TestPeerProtectionIdempotent(t *testing.T) { + cm, err := NewConnManager(10, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour)) + require.NoError(t, err) + defer cm.Close() + + id, _ := tu.RandPeerID() + cm.Protect(id, "global") + cm.Protect(id, "global") + cm.Protect(id, "global") + cm.Protect(id, "global") + + if len(cm.protected[id]) > 1 { + t.Error("expected peer to be protected only once") + } + + if !cm.Unprotect(id, "unused") { + t.Error("expected peer to continue to be protected") + } + + if !cm.Unprotect(id, "unused2") { + t.Error("expected peer to continue to be protected") + } + + if cm.Unprotect(id, "global") { + t.Error("expected peer to be unprotected") + } + + if len(cm.protected) > 0 { + t.Error("expected no protections") + } +} + +func TestUpsertTag(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute)) + require.NoError(t, err) + defer cm.Close() + not := cm.Notifee() + conn := randConn(t, nil) + rp := conn.RemotePeer() + + // this is an early tag, before the Connected notification arrived. + cm.UpsertTag(rp, "tag", func(v int) int { return v + 1 }) + if len(cm.segments.get(rp).peers[rp].tags) != 1 { + t.Fatal("expected a tag") + } + if cm.segments.get(rp).peers[rp].value != 1 { + t.Fatal("expected a tag value of 1") + } + + // now let's notify the connection. 
+ not.Connected(nil, conn) + + cm.UpsertTag(rp, "tag", func(v int) int { return v + 1 }) + if len(cm.segments.get(rp).peers[rp].tags) != 1 { + t.Fatal("expected a tag") + } + if cm.segments.get(rp).peers[rp].value != 2 { + t.Fatal("expected a tag value of 2") + } + + cm.UpsertTag(rp, "tag", func(v int) int { return v - 1 }) + if len(cm.segments.get(rp).peers[rp].tags) != 1 { + t.Fatal("expected a tag") + } + if cm.segments.get(rp).peers[rp].value != 1 { + t.Fatal("expected a tag value of 1") + } +} + +func TestTemporaryEntriesClearedFirst(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(0)) + require.NoError(t, err) + + id := tu.RandPeerIDFatal(t) + cm.TagPeer(id, "test", 20) + + if cm.GetTagInfo(id).Value != 20 { + t.Fatal("expected an early tag with value 20") + } + + not := cm.Notifee() + conn1, conn2 := randConn(t, nil), randConn(t, nil) + not.Connected(nil, conn1) + not.Connected(nil, conn2) + + cm.TrimOpenConns(context.Background()) + if cm.GetTagInfo(id) != nil { + t.Fatal("expected no temporary tags after trimming") + } +} + +func TestTemporaryEntryConvertedOnConnection(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(0)) + require.NoError(t, err) + defer cm.Close() + + conn := randConn(t, nil) + cm.TagPeer(conn.RemotePeer(), "test", 20) + + ti := cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()] + + if ti.value != 20 || !ti.temp { + t.Fatal("expected a temporary tag with value 20") + } + + not := cm.Notifee() + not.Connected(nil, conn) + + if ti.value != 20 || ti.temp { + t.Fatal("expected a non-temporary tag with value 20") + } +} + +// see https://github.com/libp2p/go-libp2p-connmgr/issues/82 +func TestConcurrentCleanupAndTagging(t *testing.T) { + cm, err := NewConnManager(1, 1, WithGracePeriod(0), WithSilencePeriod(time.Millisecond)) + require.NoError(t, err) + defer cm.Close() + + for i := 0; i < 1000; i++ { + conn := randConn(t, nil) + cm.TagPeer(conn.RemotePeer(), "test", 20) + } +} + +type mockConn struct { + stats network.ConnStats +} + +func (m mockConn) Close() error { panic("implement me") } +func (m mockConn) LocalPeer() peer.ID { panic("implement me") } +func (m mockConn) LocalPrivateKey() crypto.PrivKey { panic("implement me") } +func (m mockConn) RemotePeer() peer.ID { panic("implement me") } +func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") } +func (m mockConn) LocalMultiaddr() ma.Multiaddr { panic("implement me") } +func (m mockConn) RemoteMultiaddr() ma.Multiaddr { panic("implement me") } +func (m mockConn) Stat() network.ConnStats { return m.stats } +func (m mockConn) ID() string { panic("implement me") } +func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic("implement me") } +func (m mockConn) GetStreams() []network.Stream { panic("implement me") } + +func TestPeerInfoSorting(t *testing.T) { + t.Run("starts with temporary connections", func(t *testing.T) { + p1 := peerInfo{id: peer.ID("peer1")} + p2 := peerInfo{id: peer.ID("peer2"), temp: true} + pis := peerInfos{p1, p2} + pis.SortByValue() + require.Equal(t, pis, peerInfos{p2, p1}) + }) + + t.Run("starts with low-value connections", func(t *testing.T) { + p1 := peerInfo{id: peer.ID("peer1"), value: 40} + p2 := peerInfo{id: peer.ID("peer2"), value: 20} + pis := peerInfos{p1, p2} + pis.SortByValue() + require.Equal(t, pis, peerInfos{p2, p1}) + }) + + t.Run("in a memory emergency, starts with incoming connections", func(t *testing.T) { + incoming := network.ConnStats{} + incoming.Direction = network.DirInbound + 
outgoing := network.ConnStats{} + outgoing.Direction = network.DirOutbound + p1 := peerInfo{ + id: peer.ID("peer1"), + conns: map[network.Conn]time.Time{ + &mockConn{stats: outgoing}: time.Now(), + }, + } + p2 := peerInfo{ + id: peer.ID("peer2"), + conns: map[network.Conn]time.Time{ + &mockConn{stats: outgoing}: time.Now(), + &mockConn{stats: incoming}: time.Now(), + }, + } + pis := peerInfos{p1, p2} + pis.SortByValueAndStreams() + require.Equal(t, pis, peerInfos{p2, p1}) + }) + + t.Run("in a memory emergency, starts with connections that have many streams", func(t *testing.T) { + p1 := peerInfo{ + id: peer.ID("peer1"), + conns: map[network.Conn]time.Time{ + &mockConn{stats: network.ConnStats{NumStreams: 100}}: time.Now(), + }, + } + p2 := peerInfo{ + id: peer.ID("peer2"), + conns: map[network.Conn]time.Time{ + &mockConn{stats: network.ConnStats{NumStreams: 80}}: time.Now(), + &mockConn{stats: network.ConnStats{NumStreams: 40}}: time.Now(), + }, + } + pis := peerInfos{p1, p2} + pis.SortByValueAndStreams() + require.Equal(t, pis, peerInfos{p2, p1}) + }) +} diff --git a/p2p/net/connmgr/decay.go b/p2p/net/connmgr/decay.go new file mode 100644 index 0000000000..189759a180 --- /dev/null +++ b/p2p/net/connmgr/decay.go @@ -0,0 +1,355 @@ +package connmgr + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/peer" + + "github.com/benbjohnson/clock" +) + +// DefaultResolution is the default resolution of the decay tracker. +var DefaultResolution = 1 * time.Minute + +// bumpCmd represents a bump command. +type bumpCmd struct { + peer peer.ID + tag *decayingTag + delta int +} + +// removeCmd represents a tag removal command. +type removeCmd struct { + peer peer.ID + tag *decayingTag +} + +// decayer tracks and manages all decaying tags and their values. +type decayer struct { + cfg *DecayerCfg + mgr *BasicConnMgr + clock clock.Clock // for testing. + + tagsMu sync.Mutex + knownTags map[string]*decayingTag + + // lastTick stores the last time the decayer ticked. Guarded by atomic. + lastTick atomic.Value + + // bumpTagCh queues bump commands to be processed by the loop. + bumpTagCh chan bumpCmd + removeTagCh chan removeCmd + closeTagCh chan *decayingTag + + // closure thingies. + closeCh chan struct{} + doneCh chan struct{} + err error +} + +var _ connmgr.Decayer = (*decayer)(nil) + +// DecayerCfg is the configuration object for the Decayer. +type DecayerCfg struct { + Resolution time.Duration + Clock clock.Clock +} + +// WithDefaults writes the default values on this DecayerConfig instance, +// and returns itself for chainability. +// +// cfg := (&DecayerCfg{}).WithDefaults() +// cfg.Resolution = 30 * time.Second +// t := NewDecayer(cfg, cm) +func (cfg *DecayerCfg) WithDefaults() *DecayerCfg { + cfg.Resolution = DefaultResolution + return cfg +} + +// NewDecayer creates a new decaying tag registry. +func NewDecayer(cfg *DecayerCfg, mgr *BasicConnMgr) (*decayer, error) { + // use real time if the Clock in the config is nil. + if cfg.Clock == nil { + cfg.Clock = clock.New() + } + + d := &decayer{ + cfg: cfg, + mgr: mgr, + clock: cfg.Clock, + knownTags: make(map[string]*decayingTag), + bumpTagCh: make(chan bumpCmd, 128), + removeTagCh: make(chan removeCmd, 128), + closeTagCh: make(chan *decayingTag, 128), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + + d.lastTick.Store(d.clock.Now()) + + // kick things off. 
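+	// the process loop runs in its own goroutine until Close is called.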
+ go d.process() + + return d, nil +} + +func (d *decayer) RegisterDecayingTag(name string, interval time.Duration, decayFn connmgr.DecayFn, bumpFn connmgr.BumpFn) (connmgr.DecayingTag, error) { + d.tagsMu.Lock() + defer d.tagsMu.Unlock() + + if _, ok := d.knownTags[name]; ok { + return nil, fmt.Errorf("decaying tag with name %s already exists", name) + } + + if interval < d.cfg.Resolution { + log.Warnf("decay interval for %s (%s) was lower than tracker's resolution (%s); overridden to resolution", + name, interval, d.cfg.Resolution) + interval = d.cfg.Resolution + } + + if interval%d.cfg.Resolution != 0 { + log.Warnf("decay interval for tag %s (%s) is not a multiple of tracker's resolution (%s); "+ + "some precision may be lost", name, interval, d.cfg.Resolution) + } + + lastTick := d.lastTick.Load().(time.Time) + tag := &decayingTag{ + trkr: d, + name: name, + interval: interval, + nextTick: lastTick.Add(interval), + decayFn: decayFn, + bumpFn: bumpFn, + } + + d.knownTags[name] = tag + return tag, nil +} + +// Close closes the Decayer. It is idempotent. +func (d *decayer) Close() error { + select { + case <-d.doneCh: + return d.err + default: + } + + close(d.closeCh) + <-d.doneCh + return d.err +} + +// process is the heart of the tracker. It performs the following duties: +// +// 1. Manages decay. +// 2. Applies score bumps. +// 3. Yields when closed. +func (d *decayer) process() { + defer close(d.doneCh) + + ticker := d.clock.Ticker(d.cfg.Resolution) + defer ticker.Stop() + + var ( + bmp bumpCmd + now time.Time + visit = make(map[*decayingTag]struct{}) + ) + + for { + select { + case now = <-ticker.C: + d.lastTick.Store(now) + + d.tagsMu.Lock() + for _, tag := range d.knownTags { + if tag.nextTick.After(now) { + // skip the tag. + continue + } + // Mark the tag to be updated in this round. + visit[tag] = struct{}{} + } + d.tagsMu.Unlock() + + // Visit each peer, and decay tags that need to be decayed. + for _, s := range d.mgr.segments { + s.Lock() + + // Entered a segment that contains peers. Process each peer. + for _, p := range s.peers { + for tag, v := range p.decaying { + if _, ok := visit[tag]; !ok { + // skip this tag. + continue + } + + // ~ this value needs to be visited. ~ + var delta int + if after, rm := tag.decayFn(*v); rm { + // delete the value and move on to the next tag. + delta -= v.Value + delete(p.decaying, tag) + } else { + // accumulate the delta, and apply the changes. + delta += after - v.Value + v.Value, v.LastVisit = after, now + } + p.value += delta + } + } + + s.Unlock() + } + + // Reset each tag's next visit round, and clear the visited set. + for tag := range visit { + tag.nextTick = tag.nextTick.Add(tag.interval) + delete(visit, tag) + } + + case bmp = <-d.bumpTagCh: + var ( + now = d.clock.Now() + peer, tag = bmp.peer, bmp.tag + ) + + s := d.mgr.segments.get(peer) + s.Lock() + + p := s.tagInfoFor(peer) + v, ok := p.decaying[tag] + if !ok { + v = &connmgr.DecayingValue{ + Tag: tag, + Peer: peer, + LastVisit: now, + Added: now, + Value: 0, + } + p.decaying[tag] = v + } + + prev := v.Value + v.Value, v.LastVisit = v.Tag.(*decayingTag).bumpFn(*v, bmp.delta), now + p.value += v.Value - prev + + s.Unlock() + + case rm := <-d.removeTagCh: + s := d.mgr.segments.get(rm.peer) + s.Lock() + + p := s.tagInfoFor(rm.peer) + v, ok := p.decaying[rm.tag] + if !ok { + s.Unlock() + continue + } + p.value -= v.Value + delete(p.decaying, rm.tag) + s.Unlock() + + case t := <-d.closeTagCh: + // Stop tracking the tag. 
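+			// Drop it from the registry first, then purge its values from every peer below.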
+ d.tagsMu.Lock() + delete(d.knownTags, t.name) + d.tagsMu.Unlock() + + // Remove the tag from all peers that had it in the connmgr. + for _, s := range d.mgr.segments { + // visit all segments, and attempt to remove the tag from all the peers it stores. + s.Lock() + for _, p := range s.peers { + if dt, ok := p.decaying[t]; ok { + // decrease the value of the tagInfo, and delete the tag. + p.value -= dt.Value + delete(p.decaying, t) + } + } + s.Unlock() + } + + case <-d.closeCh: + return + } + } +} + +// decayingTag represents a decaying tag, with an associated decay interval, a +// decay function, and a bump function. +type decayingTag struct { + trkr *decayer + name string + interval time.Duration + nextTick time.Time + decayFn connmgr.DecayFn + bumpFn connmgr.BumpFn + + // closed marks this tag as closed, so that if it's bumped after being + // closed, we can return an error. 0 = false; 1 = true; guarded by atomic. + closed int32 +} + +var _ connmgr.DecayingTag = (*decayingTag)(nil) + +func (t *decayingTag) Name() string { + return t.name +} + +func (t *decayingTag) Interval() time.Duration { + return t.interval +} + +// Bump bumps a tag for this peer. +func (t *decayingTag) Bump(p peer.ID, delta int) error { + if atomic.LoadInt32(&t.closed) == 1 { + return fmt.Errorf("decaying tag %s had been closed; no further bumps are accepted", t.name) + } + + bmp := bumpCmd{peer: p, tag: t, delta: delta} + + select { + case t.trkr.bumpTagCh <- bmp: + return nil + default: + return fmt.Errorf( + "unable to bump decaying tag for peer %s, tag %s, delta %d; queue full (len=%d)", + p.Pretty(), t.name, delta, len(t.trkr.bumpTagCh)) + } +} + +func (t *decayingTag) Remove(p peer.ID) error { + if atomic.LoadInt32(&t.closed) == 1 { + return fmt.Errorf("decaying tag %s had been closed; no further removals are accepted", t.name) + } + + rm := removeCmd{peer: p, tag: t} + + select { + case t.trkr.removeTagCh <- rm: + return nil + default: + return fmt.Errorf( + "unable to remove decaying tag for peer %s, tag %s; queue full (len=%d)", + p.Pretty(), t.name, len(t.trkr.removeTagCh)) + } +} + +func (t *decayingTag) Close() error { + if !atomic.CompareAndSwapInt32(&t.closed, 0, 1) { + log.Warnf("duplicate decaying tag closure: %s; skipping", t.name) + return nil + } + + select { + case t.trkr.closeTagCh <- t: + return nil + default: + return fmt.Errorf("unable to close decaying tag %s; queue full (len=%d)", t.name, len(t.trkr.closeTagCh)) + } +} diff --git a/p2p/net/connmgr/decay_test.go b/p2p/net/connmgr/decay_test.go new file mode 100644 index 0000000000..6c064dc2f1 --- /dev/null +++ b/p2p/net/connmgr/decay_test.go @@ -0,0 +1,421 @@ +package connmgr + +import ( + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/peer" + tu "github.com/libp2p/go-libp2p-core/test" + "github.com/stretchr/testify/require" + + "github.com/benbjohnson/clock" +) + +const TestResolution = 50 * time.Millisecond + +func TestDecayExpire(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, mockClock = testDecayTracker(t) + ) + + tag, err := decay.RegisterDecayingTag("pop", 250*time.Millisecond, connmgr.DecayExpireWhenInactive(1*time.Second), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + err = tag.Bump(id, 10) + if err != nil { + t.Fatal(err) + } + + // give time for the bump command to process. 
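+	// Bumps are queued on a channel and applied asynchronously by the decayer's event loop.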
+ <-time.After(100 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 10 { + t.Fatalf("wrong value; expected = %d; got = %d", 10, v) + } + + mockClock.Add(250 * time.Millisecond) + mockClock.Add(250 * time.Millisecond) + mockClock.Add(250 * time.Millisecond) + mockClock.Add(250 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 0 { + t.Fatalf("wrong value; expected = %d; got = %d", 0, v) + } +} + +func TestMultipleBumps(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, _ = testDecayTracker(t) + ) + + tag, err := decay.RegisterDecayingTag("pop", 250*time.Millisecond, connmgr.DecayExpireWhenInactive(1*time.Second), connmgr.BumpSumBounded(10, 20)) + if err != nil { + t.Fatal(err) + } + + err = tag.Bump(id, 5) + if err != nil { + t.Fatal(err) + } + + <-time.After(100 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 10 { + t.Fatalf("wrong value; expected = %d; got = %d", 10, v) + } + + err = tag.Bump(id, 100) + if err != nil { + t.Fatal(err) + } + + <-time.After(100 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 20 { + t.Fatalf("wrong value; expected = %d; got = %d", 20, v) + } +} + +func TestMultipleTagsNoDecay(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, _ = testDecayTracker(t) + ) + + tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100)) + if err != nil { + t.Fatal(err) + } + + tag2, err := decay.RegisterDecayingTag("bop", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100)) + if err != nil { + t.Fatal(err) + } + + tag3, err := decay.RegisterDecayingTag("foo", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100)) + if err != nil { + t.Fatal(err) + } + + _ = tag1.Bump(id, 100) + _ = tag2.Bump(id, 100) + _ = tag3.Bump(id, 100) + _ = tag1.Bump(id, 100) + _ = tag2.Bump(id, 100) + _ = tag3.Bump(id, 100) + + <-time.After(500 * time.Millisecond) + + // all tags are upper-bounded, so the score must be 300 + ti := mgr.GetTagInfo(id) + if v := ti.Value; v != 300 { + t.Fatalf("wrong value; expected = %d; got = %d", 300, v) + } + + for _, s := range []string{"beep", "bop", "foo"} { + if v, ok := ti.Tags[s]; !ok || v != 100 { + t.Fatalf("expected tag %s to be 100; was = %d", s, v) + } + } +} + +func TestCustomFunctions(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, mockClock = testDecayTracker(t) + ) + + tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayFixed(10), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + tag2, err := decay.RegisterDecayingTag("bop", 100*time.Millisecond, connmgr.DecayFixed(5), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + tag3, err := decay.RegisterDecayingTag("foo", 50*time.Millisecond, connmgr.DecayFixed(1), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + _ = tag1.Bump(id, 1000) + _ = tag2.Bump(id, 1000) + _ = tag3.Bump(id, 1000) + + <-time.After(500 * time.Millisecond) + + // no decay has occurred yet, so score must be 3000. + if v := mgr.GetTagInfo(id).Value; v != 3000 { + t.Fatalf("wrong value; expected = %d; got = %d", 3000, v) + } + + // only tag3 should tick. + mockClock.Add(50 * time.Millisecond) + if v := mgr.GetTagInfo(id).Value; v != 2999 { + t.Fatalf("wrong value; expected = %d; got = %d", 2999, v) + } + + // tag3 will tick thrice, tag2 will tick twice. 
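+	// Expected: 2999 - (3 * 1) - (2 * 5) = 2986.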
+ mockClock.Add(150 * time.Millisecond) + if v := mgr.GetTagInfo(id).Value; v != 2986 { + t.Fatalf("wrong value; expected = %d; got = %d", 2986, v) + } + + // tag3 will tick once, tag1 will tick once. + mockClock.Add(50 * time.Millisecond) + if v := mgr.GetTagInfo(id).Value; v != 2975 { + t.Fatalf("wrong value; expected = %d; got = %d", 2975, v) + } +} + +func TestMultiplePeers(t *testing.T) { + var ( + ids = []peer.ID{tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t)} + mgr, decay, mockClock = testDecayTracker(t) + ) + + tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayFixed(10), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + tag2, err := decay.RegisterDecayingTag("bop", 100*time.Millisecond, connmgr.DecayFixed(5), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + tag3, err := decay.RegisterDecayingTag("foo", 50*time.Millisecond, connmgr.DecayFixed(1), connmgr.BumpSumUnbounded()) + if err != nil { + t.Fatal(err) + } + + _ = tag1.Bump(ids[0], 1000) + _ = tag2.Bump(ids[0], 1000) + _ = tag3.Bump(ids[0], 1000) + + _ = tag1.Bump(ids[1], 500) + _ = tag2.Bump(ids[1], 500) + _ = tag3.Bump(ids[1], 500) + + _ = tag1.Bump(ids[2], 100) + _ = tag2.Bump(ids[2], 100) + _ = tag3.Bump(ids[2], 100) + + // allow the background goroutine to process bumps. + <-time.After(500 * time.Millisecond) + + mockClock.Add(3 * time.Second) + + // allow the background goroutine to process ticks. + <-time.After(500 * time.Millisecond) + + if v := mgr.GetTagInfo(ids[0]).Value; v != 2670 { + t.Fatalf("wrong value; expected = %d; got = %d", 2670, v) + } + + if v := mgr.GetTagInfo(ids[1]).Value; v != 1170 { + t.Fatalf("wrong value; expected = %d; got = %d", 1170, v) + } + + if v := mgr.GetTagInfo(ids[2]).Value; v != 40 { + t.Fatalf("wrong value; expected = %d; got = %d", 40, v) + } +} + +func TestLinearDecayOverwrite(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, mockClock = testDecayTracker(t) + ) + + tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayLinear(0.5), connmgr.BumpOverwrite()) + if err != nil { + t.Fatal(err) + } + + _ = tag1.Bump(id, 1000) + // allow the background goroutine to process bumps. + <-time.After(500 * time.Millisecond) + + mockClock.Add(250 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 500 { + t.Fatalf("value should be half; got = %d", v) + } + + mockClock.Add(250 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 250 { + t.Fatalf("value should be half; got = %d", v) + } + + _ = tag1.Bump(id, 1000) + // allow the background goroutine to process bumps. + <-time.After(500 * time.Millisecond) + + if v := mgr.GetTagInfo(id).Value; v != 1000 { + t.Fatalf("value should 1000; got = %d", v) + } +} + +func TestResolutionMisaligned(t *testing.T) { + var ( + id = tu.RandPeerIDFatal(t) + mgr, decay, mockClock = testDecayTracker(t) + require = require.New(t) + ) + + tag1, err := decay.RegisterDecayingTag("beep", time.Duration(float64(TestResolution)*1.4), connmgr.DecayFixed(1), connmgr.BumpOverwrite()) + require.NoError(err) + + tag2, err := decay.RegisterDecayingTag("bop", time.Duration(float64(TestResolution)*2.4), connmgr.DecayFixed(1), connmgr.BumpOverwrite()) + require.NoError(err) + + _ = tag1.Bump(id, 1000) + _ = tag2.Bump(id, 1000) + // allow the background goroutine to process bumps. + <-time.After(500 * time.Millisecond) + + // first tick. 
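+	// Neither interval (1.4x and 2.4x the resolution) has elapsed yet, so both tags stay at 1000.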
+	mockClock.Add(TestResolution)
+	require.Equal(1000, mgr.GetTagInfo(id).Tags["beep"])
+	require.Equal(1000, mgr.GetTagInfo(id).Tags["bop"])
+
+	// next tick; tag1 would've ticked.
+	mockClock.Add(TestResolution)
+	require.Equal(999, mgr.GetTagInfo(id).Tags["beep"])
+	require.Equal(1000, mgr.GetTagInfo(id).Tags["bop"])
+
+	// next tick; tag1 would've ticked twice, tag2 once.
+	mockClock.Add(TestResolution)
+	require.Equal(998, mgr.GetTagInfo(id).Tags["beep"])
+	require.Equal(999, mgr.GetTagInfo(id).Tags["bop"])
+
+	require.Equal(1997, mgr.GetTagInfo(id).Value)
+}
+
+func TestTagRemoval(t *testing.T) {
+	var (
+		id1, id2              = tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t)
+		mgr, decay, mockClock = testDecayTracker(t)
+		require               = require.New(t)
+	)
+
+	tag1, err := decay.RegisterDecayingTag("beep", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
+	require.NoError(err)
+
+	tag2, err := decay.RegisterDecayingTag("bop", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
+	require.NoError(err)
+
+	// id1 has both tags; id2 only has the first tag.
+	_ = tag1.Bump(id1, 1000)
+	_ = tag2.Bump(id1, 1000)
+	_ = tag1.Bump(id2, 1000)
+
+	// allow the background goroutine to process bumps.
+	<-time.After(500 * time.Millisecond)
+
+	// first tick.
+	mockClock.Add(TestResolution)
+	require.Equal(999, mgr.GetTagInfo(id1).Tags["beep"])
+	require.Equal(999, mgr.GetTagInfo(id1).Tags["bop"])
+	require.Equal(999, mgr.GetTagInfo(id2).Tags["beep"])
+
+	require.Equal(999*2, mgr.GetTagInfo(id1).Value)
+	require.Equal(999, mgr.GetTagInfo(id2).Value)
+
+	// remove tag1 from id1.
+	err = tag1.Remove(id1)
+
+	// allow the background goroutine to process the removal.
+	<-time.After(500 * time.Millisecond)
+	require.NoError(err)
+
+	// next tick. both peers only have 1 tag, both at 998 value.
+	mockClock.Add(TestResolution)
+	require.Zero(mgr.GetTagInfo(id1).Tags["beep"])
+	require.Equal(998, mgr.GetTagInfo(id1).Tags["bop"])
+	require.Equal(998, mgr.GetTagInfo(id2).Tags["beep"])
+
+	require.Equal(998, mgr.GetTagInfo(id1).Value)
+	require.Equal(998, mgr.GetTagInfo(id2).Value)
+
+	// remove tag1 from id1 again; no error.
+	err = tag1.Remove(id1)
+	require.NoError(err)
+}
+
+func TestTagClosure(t *testing.T) {
+	var (
+		id                    = tu.RandPeerIDFatal(t)
+		mgr, decay, mockClock = testDecayTracker(t)
+		require               = require.New(t)
+	)
+
+	tag1, err := decay.RegisterDecayingTag("beep", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
+	require.NoError(err)
+
+	tag2, err := decay.RegisterDecayingTag("bop", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
+	require.NoError(err)
+
+	_ = tag1.Bump(id, 1000)
+	_ = tag2.Bump(id, 1000)
+	// allow the background goroutine to process bumps.
+	<-time.After(500 * time.Millisecond)
+
+	// first tick: both tags decay by 1.
+	mockClock.Add(TestResolution)
+	require.Equal(999, mgr.GetTagInfo(id).Tags["beep"])
+	require.Equal(999, mgr.GetTagInfo(id).Tags["bop"])
+	require.Equal(999*2, mgr.GetTagInfo(id).Value)
+
+	// second tick: both tags decay again.
+	mockClock.Add(TestResolution)
+	require.Equal(998, mgr.GetTagInfo(id).Tags["beep"])
+	require.Equal(998, mgr.GetTagInfo(id).Tags["bop"])
+	require.Equal(998*2, mgr.GetTagInfo(id).Value)
+
+	// close the tag.
+	err = tag1.Close()
+	require.NoError(err)
+
+	// allow the background goroutine to process the closure.
+	<-time.After(500 * time.Millisecond)
+	require.Equal(998, mgr.GetTagInfo(id).Value)
+
+	// a second closure should not error.
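+	// Close is idempotent: a duplicate close only logs a warning and returns nil.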
+	err = tag1.Close()
+	require.NoError(err)
+
+	// bumping a tag after it's been closed should error.
+	err = tag1.Bump(id, 5)
+	require.Error(err)
+}
+
+func testDecayTracker(tb testing.TB) (*BasicConnMgr, connmgr.Decayer, *clock.Mock) {
+	mockClock := clock.NewMock()
+	cfg := &DecayerCfg{
+		Resolution: TestResolution,
+		Clock:      mockClock,
+	}
+
+	mgr, err := NewConnManager(10, 10, WithGracePeriod(time.Second), DecayerConfig(cfg))
+	require.NoError(tb, err)
+	decay, ok := connmgr.SupportsDecay(mgr)
+	if !ok {
+		tb.Fatalf("connmgr does not support decay")
+	}
+	tb.Cleanup(func() {
+		mgr.Close()
+		decay.Close()
+	})
+
+	return mgr, decay, mockClock
+}
diff --git a/p2p/net/connmgr/options.go b/p2p/net/connmgr/options.go
new file mode 100644
index 0000000000..76b4ef386d
--- /dev/null
+++ b/p2p/net/connmgr/options.go
@@ -0,0 +1,53 @@
+package connmgr
+
+import (
+	"errors"
+	"time"
+)
+
+// config is the configuration struct for the basic connection manager.
+type config struct {
+	highWater     int
+	lowWater      int
+	gracePeriod   time.Duration
+	silencePeriod time.Duration
+	decayer       *DecayerCfg
+	emergencyTrim bool
+}
+
+// Option represents an option for the basic connection manager.
+type Option func(*config) error
+
+// DecayerConfig applies a configuration for the decayer.
+func DecayerConfig(opts *DecayerCfg) Option {
+	return func(cfg *config) error {
+		cfg.decayer = opts
+		return nil
+	}
+}
+
+// WithGracePeriod sets the grace period.
+// The grace period is the time a newly opened connection is given before it becomes
+// subject to pruning.
+func WithGracePeriod(p time.Duration) Option {
+	return func(cfg *config) error {
+		if p < 0 {
+			return errors.New("grace period must be non-negative")
+		}
+		cfg.gracePeriod = p
+		return nil
+	}
+}
+
+// WithSilencePeriod sets the silence period.
+// The connection manager will perform a cleanup once per silence period
+// if the number of connections surpasses the high watermark.
+func WithSilencePeriod(p time.Duration) Option {
+	return func(cfg *config) error {
+		if p <= 0 {
+			return errors.New("silence period must be positive")
+		}
+		cfg.silencePeriod = p
+		return nil
+	}
+}
diff --git a/p2p/net/connmgr/watchdog_cgo.go b/p2p/net/connmgr/watchdog_cgo.go
new file mode 100644
index 0000000000..b048111f36
--- /dev/null
+++ b/p2p/net/connmgr/watchdog_cgo.go
@@ -0,0 +1,18 @@
+//go:build cgo
+// +build cgo
+
+package connmgr
+
+import "github.com/raulk/go-watchdog"
+
+func registerWatchdog(cb func()) (unregister func()) {
+	return watchdog.RegisterPostGCNotifee(cb)
+}
+
+// WithEmergencyTrim is an option to enable trimming connections in a memory emergency.
+func WithEmergencyTrim(enable bool) Option {
+	return func(cfg *config) error {
+		cfg.emergencyTrim = enable
+		return nil
+	}
+}
diff --git a/p2p/net/connmgr/watchdog_no_cgo.go b/p2p/net/connmgr/watchdog_no_cgo.go
new file mode 100644
index 0000000000..fb83c286de
--- /dev/null
+++ b/p2p/net/connmgr/watchdog_no_cgo.go
@@ -0,0 +1,16 @@
+//go:build !cgo
+// +build !cgo
+
+package connmgr
+
+func registerWatchdog(func()) (unregister func()) {
+	return nil
+}
+
+// WithEmergencyTrim is an option to enable trimming connections in a memory emergency.
+func WithEmergencyTrim(enable bool) Option {
+	return func(cfg *config) error {
+		log.Warn("platform doesn't support go-watchdog")
+		return nil
+	}
+}
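Usage note: the snippet below is a minimal sketch (not part of the patch) showing how the new options and the decaying-tag API fit together. It assumes the go-libp2p-core connmgr decay helpers exercised in the tests above (SupportsDecay, DecayFixed, BumpSumBounded); the watermarks, periods, and the tag name are illustrative only.

package main

import (
	"time"

	"github.com/libp2p/go-libp2p-core/connmgr"
	"github.com/libp2p/go-libp2p-core/peer"

	cm "github.com/libp2p/go-libp2p/p2p/net/connmgr"
)

func main() {
	// Keep between 100 (low watermark) and 400 (high watermark) connections,
	// protect new connections for 30s, trim at most once every 10s, and trim
	// eagerly under memory pressure on cgo builds.
	mgr, err := cm.NewConnManager(
		100, 400,
		cm.WithGracePeriod(30*time.Second),
		cm.WithSilencePeriod(10*time.Second),
		cm.WithEmergencyTrim(true),
		cm.DecayerConfig((&cm.DecayerCfg{}).WithDefaults()),
	)
	if err != nil {
		panic(err)
	}
	defer mgr.Close()

	// The decaying-tag registry is reached through the core Decayer interface.
	decayer, ok := connmgr.SupportsDecay(mgr)
	if !ok {
		panic("connection manager does not support decaying tags")
	}
	defer decayer.Close()

	// A score that loses 1 point per decay interval and is capped at 100.
	tag, err := decayer.RegisterDecayingTag(
		"useful-peer", time.Minute,
		connmgr.DecayFixed(1), connmgr.BumpSumBounded(0, 100),
	)
	if err != nil {
		panic(err)
	}

	var p peer.ID // placeholder: in practice, the ID of a connected peer
	_ = tag.Bump(p, 10)
}

On builds without cgo, WithEmergencyTrim only logs a warning because go-watchdog is unavailable; the rest of the configuration behaves the same.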