From efb73034b24f41894816864ba452af561edc3efc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Sep 2021 19:35:04 +0100 Subject: [PATCH] add a Close method, remove the context from the constructor --- relay.go | 13 ++++---- relay_test.go | 78 +++++++++++++++++------------------------------ transport.go | 12 ++++++-- transport_test.go | 24 ++++++--------- 4 files changed, 54 insertions(+), 73 deletions(-) diff --git a/relay.go b/relay.go index 8db4e80..8c7f46d 100644 --- a/relay.go +++ b/relay.go @@ -41,10 +41,11 @@ var ( // Relay is the relay transport and service. type Relay struct { - host host.Host - upgrader *tptu.Upgrader - ctx context.Context - self peer.ID + host host.Host + upgrader *tptu.Upgrader + ctx context.Context + ctxCancel context.CancelFunc + self peer.ID active bool hop bool @@ -93,15 +94,15 @@ func (e RelayError) Error() string { } // NewRelay constructs a new relay. -func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) { +func NewRelay(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) { r := &Relay{ upgrader: upgrader, host: h, - ctx: ctx, self: h.ID(), incoming: make(chan *Conn), hopCount: make(map[peer.ID]int), } + r.ctx, r.ctxCancel = context.WithCancel(context.Background()) for _, opt := range opts { switch opt { diff --git a/relay_test.go b/relay_test.go index 1fcc54b..8f397a9 100644 --- a/relay_test.go +++ b/relay_test.go @@ -38,14 +38,13 @@ func getNetHosts(t *testing.T, n int) []host.Host { netw := swarmt.GenSwarm(t) h := bhost.NewBlankHost(netw) out = append(out, h) - t.Cleanup(func() { h.Close() }) } return out } -func newTestRelay(t *testing.T, ctx context.Context, host host.Host, opts ...RelayOpt) *Relay { - r, err := NewRelay(ctx, host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...) +func newTestRelay(t *testing.T, host host.Host, opts ...RelayOpt) *Relay { + r, err := NewRelay(host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...) if err != nil { t.Fatal(err) } @@ -71,11 +70,11 @@ func TestBasicRelay(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - newTestRelay(t, ctx, hosts[1], OptHop) + newTestRelay(t, hosts[1], OptHop) - r3 := newTestRelay(t, ctx, hosts[2]) + r3 := newTestRelay(t, hosts[2]) var ( conn1, conn2 net.Conn @@ -145,11 +144,11 @@ func TestRelayReset(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - newTestRelay(t, ctx, hosts[1], OptHop) + newTestRelay(t, hosts[1], OptHop) - r3 := newTestRelay(t, ctx, hosts[2]) + r3 := newTestRelay(t, hosts[2]) ready := make(chan struct{}) @@ -203,10 +202,10 @@ func TestBasicRelayDial(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - _ = newTestRelay(t, ctx, hosts[1], OptHop) - r3 := newTestRelay(t, ctx, hosts[2]) + _ = newTestRelay(t, hosts[1], OptHop) + r3 := newTestRelay(t, hosts[2]) var ( conn1, conn2 net.Conn @@ -266,49 +265,28 @@ func TestBasicRelayDial(t *testing.T) { } func TestUnspecificRelayDialFails(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - hosts := getNetHosts(t, 3) - r1 := newTestRelay(t, ctx, hosts[0]) - - newTestRelay(t, ctx, hosts[1], OptHop) - - r3 := newTestRelay(t, ctx, hosts[2]) + r1 := newTestRelay(t, hosts[0]) + newTestRelay(t, hosts[1], OptHop) + r3 := newTestRelay(t, hosts[2]) connect(t, hosts[0], hosts[1]) connect(t, hosts[1], hosts[2]) time.Sleep(100 * time.Millisecond) - var ( - done = make(chan struct{}) - ) - - defer func() { - cancel() - <-done - }() - go func() { - defer close(done) - list := r3.Listener() - - var err error - _, err = list.Accept() - if err == nil { + if _, err := r3.Listener().Accept(); err == nil { t.Error("should not have received relay connection") } }() addr := ma.StringCast("/p2p-circuit") - rctx, rcancel := context.WithTimeout(ctx, time.Second) - defer rcancel() - - var err error - _, err = r1.Dial(rctx, addr, hosts[2].ID()) - if err == nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if _, err := r1.Dial(ctx, addr, hosts[2].ID()); err == nil { t.Fatal("expected dial with unspecified relay address to fail, even if we're connected to a relay") } } @@ -324,11 +302,11 @@ func TestRelayThroughNonHop(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - newTestRelay(t, ctx, hosts[1]) + newTestRelay(t, hosts[1]) - newTestRelay(t, ctx, hosts[2]) + newTestRelay(t, hosts[2]) rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID()) dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID()) @@ -361,9 +339,9 @@ func TestRelayNoDestConnection(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - newTestRelay(t, ctx, hosts[1], OptHop) + newTestRelay(t, hosts[1], OptHop) rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID()) dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID()) @@ -396,9 +374,9 @@ func TestActiveRelay(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) - newTestRelay(t, ctx, hosts[1], OptHop, OptActive) - r3 := newTestRelay(t, ctx, hosts[2]) + r1 := newTestRelay(t, hosts[0]) + newTestRelay(t, hosts[1], OptHop, OptActive) + r3 := newTestRelay(t, hosts[2]) connChan := make(chan manet.Conn) @@ -458,9 +436,9 @@ func TestRelayCanHop(t *testing.T) { time.Sleep(10 * time.Millisecond) - r1 := newTestRelay(t, ctx, hosts[0]) + r1 := newTestRelay(t, hosts[0]) - newTestRelay(t, ctx, hosts[1], OptHop) + newTestRelay(t, hosts[1], OptHop) canhop, err := r1.CanHop(ctx, hosts[1].ID()) if err != nil { diff --git a/transport.go b/transport.go index 03b32db..a064a9f 100644 --- a/transport.go +++ b/transport.go @@ -1,8 +1,8 @@ package relay import ( - "context" "fmt" + "io" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/transport" @@ -13,6 +13,7 @@ import ( var circuitAddr = ma.Cast(ma.ProtocolWithCode(ma.P_CIRCUIT).VCode) var _ transport.Transport = (*RelayTransport)(nil) +var _ io.Closer = (*RelayTransport)(nil) type RelayTransport Relay @@ -45,14 +46,19 @@ func (t *RelayTransport) Protocols() []int { return []int{ma.P_CIRCUIT} } +func (r *RelayTransport) Close() error { + r.ctxCancel() + return nil +} + // AddRelayTransport constructs a relay and adds it as a transport to the host network. -func AddRelayTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error { +func AddRelayTransport(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error { n, ok := h.Network().(transport.TransportNetwork) if !ok { return fmt.Errorf("%v is not a transport network", h.Network()) } - r, err := NewRelay(ctx, h, upgrader, opts...) + r, err := NewRelay(h, upgrader, opts...) if err != nil { return err } diff --git a/transport_test.go b/transport_test.go index 0b2f3a2..55f0cfe 100644 --- a/transport_test.go +++ b/transport_test.go @@ -23,20 +23,20 @@ const TestProto = "test/relay-transport" var msg = []byte("relay works!") -func testSetupRelay(t *testing.T, ctx context.Context) []host.Host { +func testSetupRelay(t *testing.T) []host.Host { hosts := getNetHosts(t, 3) - err := AddRelayTransport(ctx, hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm))) + err := AddRelayTransport(hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm))) if err != nil { t.Fatal(err) } - err = AddRelayTransport(ctx, hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop) + err = AddRelayTransport(hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop) if err != nil { t.Fatal(err) } - err = AddRelayTransport(ctx, hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm))) + err = AddRelayTransport(hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm))) if err != nil { t.Fatal(err) } @@ -60,10 +60,7 @@ func testSetupRelay(t *testing.T, ctx context.Context) []host.Host { } func TestFullAddressTransportDial(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - hosts := testSetupRelay(t, ctx) + hosts := testSetupRelay(t) var relayAddr ma.Multiaddr for _, addr := range hosts[1].Addrs() { @@ -78,12 +75,11 @@ func TestFullAddressTransportDial(t *testing.T) { t.Fatal(err) } - rctx, rcancel := context.WithTimeout(ctx, time.Second) - defer rcancel() - hosts[0].Peerstore().AddAddrs(hosts[2].ID(), []ma.Multiaddr{addr}, peerstore.TempAddrTTL) - s, err := hosts[0].NewStream(rctx, hosts[2].ID(), TestProto) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + s, err := hosts[0].NewStream(ctx, hosts[2].ID(), TestProto) if err != nil { t.Fatal(err) } @@ -102,7 +98,7 @@ func TestSpecificRelayTransportDial(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hosts := testSetupRelay(t, ctx) + hosts := testSetupRelay(t) addr, err := ma.NewMultiaddr(fmt.Sprintf("/ipfs/%s/p2p-circuit/ipfs/%s", hosts[1].ID().Pretty(), hosts[2].ID().Pretty())) if err != nil { @@ -133,7 +129,7 @@ func TestUnspecificRelayTransportDialFails(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hosts := testSetupRelay(t, ctx) + hosts := testSetupRelay(t) addr, err := ma.NewMultiaddr(fmt.Sprintf("/p2p-circuit/ipfs/%s", hosts[2].ID().Pretty())) if err != nil {