From 94628cf4a665d2a921a3eeceb627ba6bfd7be3fb Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 17 Jul 2024 04:10:44 -0700 Subject: [PATCH] Make it safe to roundtrip SplitXXX and Join (#250) --- multiaddr_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++ util.go | 9 +++++++ 2 files changed, 72 insertions(+) diff --git a/multiaddr_test.go b/multiaddr_test.go index 6d160b4..eda478f 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -974,3 +974,66 @@ func TestHTTPPath(t *testing.T) { require.Equal(t, "tmp/bar", string(component.RawValue())) }) } + +func FuzzSplitRoundtrip(f *testing.F) { + for _, v := range good { + f.Add(v) + } + otherMultiaddr := StringCast("/udp/1337") + + f.Fuzz(func(t *testing.T, addrStr string) { + addr, err := NewMultiaddr(addrStr) + if err != nil { + t.Skip() // Skip inputs that are not valid multiaddrs + } + + // Test SplitFirst + first, rest := SplitFirst(addr) + joined := Join(first, rest) + require.Equal(t, addr, joined, "SplitFirst and Join should round-trip") + + // Test SplitLast + rest, last := SplitLast(addr) + joined = Join(rest, last) + require.Equal(t, addr, joined, "SplitLast and Join should round-trip") + + p := addr.Protocols() + if len(p) == 0 { + t.Skip() + } + + tryPubMethods := func(a Multiaddr) { + if a == nil { + return + } + _ = a.Equal(otherMultiaddr) + _ = a.Bytes() + _ = a.String() + _ = a.Protocols() + _ = a.Encapsulate(otherMultiaddr) + _ = a.Decapsulate(otherMultiaddr) + _, _ = a.ValueForProtocol(P_TCP) + } + + for _, proto := range p { + splitFunc := func(c Component) bool { + return c.Protocol().Code == proto.Code + } + beforeC, after := SplitFirst(addr) + joined = Join(beforeC, after) + require.Equal(t, addr, joined) + tryPubMethods(after) + + before, afterC := SplitLast(addr) + joined = Join(before, afterC) + require.Equal(t, addr, joined) + tryPubMethods(before) + + before, after = SplitFunc(addr, splitFunc) + joined = Join(before, after) + require.Equal(t, addr, joined) + tryPubMethods(before) + tryPubMethods(after) + } + }) +} diff --git a/util.go b/util.go index b0ac7ee..63abbcc 100644 --- a/util.go +++ b/util.go @@ -28,12 +28,21 @@ func Join(ms ...Multiaddr) Multiaddr { length := 0 for _, m := range ms { + if m == nil { + continue + } length += len(m.Bytes()) } bidx := 0 b := make([]byte, length) + if length == 0 { + return nil + } for _, mb := range ms { + if mb == nil { + continue + } bidx += copy(b[bidx:], mb.Bytes()) } return &multiaddr{bytes: b}