diff --git a/p2p/host/autonat/svc.go b/p2p/host/autonat/svc.go index bdde08928a..1be5e228b6 100644 --- a/p2p/host/autonat/svc.go +++ b/p2p/host/autonat/svc.go @@ -1,6 +1,7 @@ package autonat import ( + "bytes" "context" "errors" "math/rand" @@ -123,7 +124,12 @@ func (as *autoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Me } if as.config.dialPolicy.skipDial(addr) { - continue + err, newobsaddr := patchObsaddr(addr, obsaddr) + if err == nil { + addr = newobsaddr + } else { + continue + } } if ip, err := manet.ToIP(addr); err != nil || !obsHost.Equal(ip) { @@ -228,3 +234,61 @@ func (as *autoNATService) background(ctx context.Context) { } } } + +// patchObsaddr replaces obsaddr's port number with the port number of `localaddr` +func patchObsaddr(localaddr, obsaddr ma.Multiaddr) (error, ma.Multiaddr) { + if localaddr == nil || obsaddr == nil { + return errors.New("localaddr and obsaddr can't be nil"), nil + } + var rawport []byte + var code int + var newc ma.Component + isValid := false + ma.ForEach(localaddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_UDP, ma.P_TCP: + code = c.Protocol().Code + rawport = c.RawValue() + newc = c + return !isValid + case ma.P_IP4, ma.P_IP6: + isValid = true + } + return true + }) + + if isValid == true && len(rawport) > 0 { + obsbytes := obsaddr.Bytes() + obsoffset := 0 + isObsValid := false + isReplaced := false + var buffer bytes.Buffer + ma.ForEach(obsaddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_UDP, ma.P_TCP: + if code == c.Protocol().Code && isObsValid == true { //obsaddr has the same type protocol, and we can replace it. + if bytes.Compare(rawport, c.RawValue()) != 0 { + buffer.Write(obsbytes[:obsoffset]) + buffer.Write(newc.Bytes()) + tail := obsoffset + len(c.Bytes()) + if len(obsbytes)-tail > 0 { + buffer.Write(obsbytes[tail:]) + } + isReplaced = true + } + return false + } + case ma.P_IP4, ma.P_IP6: + isObsValid = true + } + obsoffset += len(c.Bytes()) + return true + }) + if isReplaced == true { + newobsaddr, err := ma.NewMultiaddrBytes(buffer.Bytes()) + return err, newobsaddr + } + } + + return errors.New("only same protocol address can be patched."), nil +} diff --git a/p2p/host/autonat/svc_test.go b/p2p/host/autonat/svc_test.go index e1df6c8a8c..3faa8cbee7 100644 --- a/p2p/host/autonat/svc_test.go +++ b/p2p/host/autonat/svc_test.go @@ -207,3 +207,25 @@ func TestAutoNATServiceStartup(t *testing.T) { t.Fatalf("autonat should report public, but didn't") } } + +func TestMultiaddrPatchSuccess(t *testing.T) { + m1, _ := ma.NewMultiaddr("/ip4/192.168.0.10/tcp/64555") + m2, _ := ma.NewMultiaddr("/ip4/72.53.243.114/tcp/19005") + correctm2, _ := ma.NewMultiaddr("/ip4/72.53.243.114/tcp/64555") + err, newm2 := patchObsaddr(m1, m2) + if err != nil { + t.Fatalf("patchObsaddr failed, was %s error %s", m2, err) + } + if newm2.Equal(correctm2) == false { + t.Fatalf("patchObsaddr success, but new obsaddr is %s should be %s", newm2, correctm2) + } +} + +func TestMultiaddrPatchError(t *testing.T) { + m1, _ := ma.NewMultiaddr("/ip4/192.168.0.10/tcp/64555") + m2, _ := ma.NewMultiaddr("/ip4/72.53.243.114/udp/19005") + err, newm2 := patchObsaddr(m1, m2) + if err == nil { + t.Fatalf("this address should not be patched, new address: %s", newm2) + } +}