Skip to content
This repository has been archived by the owner on Sep 9, 2022. It is now read-only.

Commit

Permalink
add a Close method, remove the context from the constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 8, 2021
1 parent 0fdb26f commit efb7303
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 73 deletions.
13 changes: 7 additions & 6 deletions relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ var (

// Relay is the relay transport and service.
type Relay struct {
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
self peer.ID
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
self peer.ID

active bool
hop bool
Expand Down Expand Up @@ -93,15 +94,15 @@ func (e RelayError) Error() string {
}

// NewRelay constructs a new relay.
func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
func NewRelay(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) (*Relay, error) {
r := &Relay{
upgrader: upgrader,
host: h,
ctx: ctx,
self: h.ID(),
incoming: make(chan *Conn),
hopCount: make(map[peer.ID]int),
}
r.ctx, r.ctxCancel = context.WithCancel(context.Background())

for _, opt := range opts {
switch opt {
Expand Down
78 changes: 28 additions & 50 deletions relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ func getNetHosts(t *testing.T, n int) []host.Host {
netw := swarmt.GenSwarm(t)
h := bhost.NewBlankHost(netw)
out = append(out, h)
t.Cleanup(func() { h.Close() })
}

return out
}

func newTestRelay(t *testing.T, ctx context.Context, host host.Host, opts ...RelayOpt) *Relay {
r, err := NewRelay(ctx, host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...)
func newTestRelay(t *testing.T, host host.Host, opts ...RelayOpt) *Relay {
r, err := NewRelay(host, swarmt.GenUpgrader(host.Network().(*swarm.Swarm)), opts...)
if err != nil {
t.Fatal(err)
}
Expand All @@ -71,11 +70,11 @@ func TestBasicRelay(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r3 := newTestRelay(t, hosts[2])

var (
conn1, conn2 net.Conn
Expand Down Expand Up @@ -145,11 +144,11 @@ func TestRelayReset(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r3 := newTestRelay(t, hosts[2])

ready := make(chan struct{})

Expand Down Expand Up @@ -203,10 +202,10 @@ func TestBasicRelayDial(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

_ = newTestRelay(t, ctx, hosts[1], OptHop)
r3 := newTestRelay(t, ctx, hosts[2])
_ = newTestRelay(t, hosts[1], OptHop)
r3 := newTestRelay(t, hosts[2])

var (
conn1, conn2 net.Conn
Expand Down Expand Up @@ -266,49 +265,28 @@ func TestBasicRelayDial(t *testing.T) {
}

func TestUnspecificRelayDialFails(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

hosts := getNetHosts(t, 3)

r1 := newTestRelay(t, ctx, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)

r3 := newTestRelay(t, ctx, hosts[2])
r1 := newTestRelay(t, hosts[0])
newTestRelay(t, hosts[1], OptHop)
r3 := newTestRelay(t, hosts[2])

connect(t, hosts[0], hosts[1])
connect(t, hosts[1], hosts[2])

time.Sleep(100 * time.Millisecond)

var (
done = make(chan struct{})
)

defer func() {
cancel()
<-done
}()

go func() {
defer close(done)
list := r3.Listener()

var err error
_, err = list.Accept()
if err == nil {
if _, err := r3.Listener().Accept(); err == nil {
t.Error("should not have received relay connection")
}
}()

addr := ma.StringCast("/p2p-circuit")

rctx, rcancel := context.WithTimeout(ctx, time.Second)
defer rcancel()

var err error
_, err = r1.Dial(rctx, addr, hosts[2].ID())
if err == nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if _, err := r1.Dial(ctx, addr, hosts[2].ID()); err == nil {
t.Fatal("expected dial with unspecified relay address to fail, even if we're connected to a relay")
}
}
Expand All @@ -324,11 +302,11 @@ func TestRelayThroughNonHop(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1])
newTestRelay(t, hosts[1])

newTestRelay(t, ctx, hosts[2])
newTestRelay(t, hosts[2])

rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID())
dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID())
Expand Down Expand Up @@ -361,9 +339,9 @@ func TestRelayNoDestConnection(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

rinfo := hosts[1].Peerstore().PeerInfo(hosts[1].ID())
dinfo := hosts[2].Peerstore().PeerInfo(hosts[2].ID())
Expand Down Expand Up @@ -396,9 +374,9 @@ func TestActiveRelay(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
newTestRelay(t, ctx, hosts[1], OptHop, OptActive)
r3 := newTestRelay(t, ctx, hosts[2])
r1 := newTestRelay(t, hosts[0])
newTestRelay(t, hosts[1], OptHop, OptActive)
r3 := newTestRelay(t, hosts[2])

connChan := make(chan manet.Conn)

Expand Down Expand Up @@ -458,9 +436,9 @@ func TestRelayCanHop(t *testing.T) {

time.Sleep(10 * time.Millisecond)

r1 := newTestRelay(t, ctx, hosts[0])
r1 := newTestRelay(t, hosts[0])

newTestRelay(t, ctx, hosts[1], OptHop)
newTestRelay(t, hosts[1], OptHop)

canhop, err := r1.CanHop(ctx, hosts[1].ID())
if err != nil {
Expand Down
12 changes: 9 additions & 3 deletions transport.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package relay

import (
"context"
"fmt"
"io"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/transport"
Expand All @@ -13,6 +13,7 @@ import (
var circuitAddr = ma.Cast(ma.ProtocolWithCode(ma.P_CIRCUIT).VCode)

var _ transport.Transport = (*RelayTransport)(nil)
var _ io.Closer = (*RelayTransport)(nil)

type RelayTransport Relay

Expand Down Expand Up @@ -45,14 +46,19 @@ func (t *RelayTransport) Protocols() []int {
return []int{ma.P_CIRCUIT}
}

func (r *RelayTransport) Close() error {
r.ctxCancel()
return nil
}

// AddRelayTransport constructs a relay and adds it as a transport to the host network.
func AddRelayTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error {
func AddRelayTransport(h host.Host, upgrader *tptu.Upgrader, opts ...RelayOpt) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}

r, err := NewRelay(ctx, h, upgrader, opts...)
r, err := NewRelay(h, upgrader, opts...)
if err != nil {
return err
}
Expand Down
24 changes: 10 additions & 14 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ const TestProto = "test/relay-transport"

var msg = []byte("relay works!")

func testSetupRelay(t *testing.T, ctx context.Context) []host.Host {
func testSetupRelay(t *testing.T) []host.Host {
hosts := getNetHosts(t, 3)

err := AddRelayTransport(ctx, hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm)))
err := AddRelayTransport(hosts[0], swarmt.GenUpgrader(hosts[0].Network().(*swarm.Swarm)))
if err != nil {
t.Fatal(err)
}

err = AddRelayTransport(ctx, hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop)
err = AddRelayTransport(hosts[1], swarmt.GenUpgrader(hosts[1].Network().(*swarm.Swarm)), OptHop)
if err != nil {
t.Fatal(err)
}

err = AddRelayTransport(ctx, hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm)))
err = AddRelayTransport(hosts[2], swarmt.GenUpgrader(hosts[2].Network().(*swarm.Swarm)))
if err != nil {
t.Fatal(err)
}
Expand All @@ -60,10 +60,7 @@ func testSetupRelay(t *testing.T, ctx context.Context) []host.Host {
}

func TestFullAddressTransportDial(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

var relayAddr ma.Multiaddr
for _, addr := range hosts[1].Addrs() {
Expand All @@ -78,12 +75,11 @@ func TestFullAddressTransportDial(t *testing.T) {
t.Fatal(err)
}

rctx, rcancel := context.WithTimeout(ctx, time.Second)
defer rcancel()

hosts[0].Peerstore().AddAddrs(hosts[2].ID(), []ma.Multiaddr{addr}, peerstore.TempAddrTTL)

s, err := hosts[0].NewStream(rctx, hosts[2].ID(), TestProto)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
s, err := hosts[0].NewStream(ctx, hosts[2].ID(), TestProto)
if err != nil {
t.Fatal(err)
}
Expand All @@ -102,7 +98,7 @@ func TestSpecificRelayTransportDial(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

addr, err := ma.NewMultiaddr(fmt.Sprintf("/ipfs/%s/p2p-circuit/ipfs/%s", hosts[1].ID().Pretty(), hosts[2].ID().Pretty()))
if err != nil {
Expand Down Expand Up @@ -133,7 +129,7 @@ func TestUnspecificRelayTransportDialFails(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := testSetupRelay(t, ctx)
hosts := testSetupRelay(t)

addr, err := ma.NewMultiaddr(fmt.Sprintf("/p2p-circuit/ipfs/%s", hosts[2].ID().Pretty()))
if err != nil {
Expand Down

0 comments on commit efb7303

Please sign in to comment.