From e89814c520f06f92eeaa4602ea24b5ad626e4cef Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Fri, 2 Jun 2023 13:26:28 -0700 Subject: [PATCH] swarm: Dedup addresses to dial (#2322) * Dedup addresses to dial Co-authored-by: Aayush Rajasekaran * Move DedupAddrs test * Typo --------- Co-authored-by: Aayush Rajasekaran --- core/network/network.go | 22 +++++++++++++++ core/network/network_test.go | 36 +++++++++++++++++++++++++ p2p/host/basic/basic_host.go | 26 ++---------------- p2p/host/basic/basic_host_test.go | 26 ------------------ p2p/net/swarm/dial_worker.go | 14 ++++++++-- p2p/net/swarm/swarm_dial.go | 1 + p2p/net/swarm/swarm_dial_test.go | 45 +++++++++++++++++++++++++++++++ 7 files changed, 118 insertions(+), 52 deletions(-) create mode 100644 core/network/network_test.go diff --git a/core/network/network.go b/core/network/network.go index 0beaac0f7..47908b8e3 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -6,8 +6,10 @@ package network import ( + "bytes" "context" "io" + "sort" "time" "github.com/libp2p/go-libp2p/core/peer" @@ -184,3 +186,23 @@ type Dialer interface { Notify(Notifiee) StopNotify(Notifiee) } + +// DedupAddrs deduplicates addresses in place, leave only unique addresses. +// It doesn't allocate. +func DedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + if len(addrs) == 0 { + return addrs + } + sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 }) + idx := 1 + for i := 1; i < len(addrs); i++ { + if !addrs[i-1].Equal(addrs[i]) { + addrs[idx] = addrs[i] + idx++ + } + } + for i := idx; i < len(addrs); i++ { + addrs[i] = nil + } + return addrs[:idx] +} diff --git a/core/network/network_test.go b/core/network/network_test.go new file mode 100644 index 000000000..a78e6d704 --- /dev/null +++ b/core/network/network_test.go @@ -0,0 +1,36 @@ +package network + +import ( + "fmt" + "testing" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestDedupAddrs(t *testing.T) { + tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234") + quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") + wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws") + + type testcase struct { + in, out []ma.Multiaddr + } + + for i, tc := range []testcase{ + {in: nil, out: nil}, + {in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, + {in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, + {in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}}, + {in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}}, + } { + tc := tc + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + deduped := DedupAddrs(tc.in) + for _, a := range tc.out { + require.Contains(t, deduped, a) + } + }) + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 5b8fa4595..d0a6c9cba 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -1,13 +1,11 @@ package basichost import ( - "bytes" "context" "errors" "fmt" "io" "net" - "sort" "sync" "time" @@ -813,26 +811,6 @@ func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { return addr } -// dedupAddrs deduplicates addresses in place, leave only unique addresses. -// It doesn't allocate. -func dedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - if len(addrs) == 0 { - return addrs - } - sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 }) - idx := 1 - for i := 1; i < len(addrs); i++ { - if !addrs[i-1].Equal(addrs[i]) { - addrs[idx] = addrs[i] - idx++ - } - } - for i := idx; i < len(addrs); i++ { - addrs[i] = nil - } - return addrs[:idx] -} - // AllAddrs returns all the addresses of BasicHost at this moment in time. // It's ok to not include addresses if they're not available to be used now. func (h *BasicHost) AllAddrs() []ma.Multiaddr { @@ -857,7 +835,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = append(finalAddrs, resolved...) } - finalAddrs = dedupAddrs(finalAddrs) + finalAddrs = network.DedupAddrs(finalAddrs) // natmgr is nil if we do not use nat option; if h.natmgr != nil { @@ -927,7 +905,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { } finalAddrs = append(finalAddrs, observedAddrs...) } - finalAddrs = dedupAddrs(finalAddrs) + finalAddrs = network.DedupAddrs(finalAddrs) finalAddrs = inferWebtransportAddrsFromQuic(finalAddrs) return finalAddrs diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 19ea2bd02..5c1babd9d 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -825,32 +825,6 @@ func TestNormalizeMultiaddr(t *testing.T) { require.Equal(t, "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", h1.NormalizeMultiaddr(ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")).String()) } -func TestDedupAddrs(t *testing.T) { - tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234") - quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") - wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws") - - type testcase struct { - in, out []ma.Multiaddr - } - - for i, tc := range []testcase{ - {in: nil, out: nil}, - {in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, - {in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, - {in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}}, - {in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}}, - } { - tc := tc - t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { - deduped := dedupAddrs(tc.in) - for _, a := range tc.out { - require.Contains(t, deduped, a) - } - }) - } -} - func TestInferWebtransportAddrsFromQuic(t *testing.T) { type testCase struct { name string diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index ba7ba87d4..a319d00e5 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -177,7 +177,12 @@ loop: case <-w.triggerDial: for _, addr := range w.nextDial { // spawn the dial - ad := w.pending[string(addr.Bytes())] + ad, ok := w.pending[string(addr.Bytes())] + if !ok { + log.Warn("unexpectedly missing pending addrDial for addr") + // Assume nothing to dial here + continue + } err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) if err != nil { w.dispatchError(ad, err) @@ -192,7 +197,12 @@ loop: w.connected = true } - ad := w.pending[string(res.Addr.Bytes())] + ad, ok := w.pending[string(res.Addr.Bytes())] + if !ok { + log.Warn("unexpectedly missing pending addrDial res") + // Assume nothing to do here + continue + } if res.Conn != nil { // we got a connection, add it to the swarm diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 5423a199b..49c0fc7fd 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -334,6 +334,7 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) } + goodAddrs = network.DedupAddrs(goodAddrs) if len(goodAddrs) == 0 { return nil, ErrNoGoodAddresses diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 566a2307f..215ee6df9 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -65,6 +65,51 @@ func TestAddrsForDial(t *testing.T) { require.NotZero(t, len(mas)) } +func TestDedupAddrsForDial(t *testing.T) { + mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)} + ipaddr, err := net.ResolveIPAddr("ip4", "1.2.3.4") + if err != nil { + t.Fatal(err) + } + mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr} + + resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver)) + if err != nil { + t.Fatal(err) + } + + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(id, priv.GetPublic()) + ps.AddPrivKey(id, priv) + t.Cleanup(func() { ps.Close() }) + + s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver)) + require.NoError(t, err) + defer s.Close() + + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + require.NoError(t, err) + err = s.AddTransport(tpt) + require.NoError(t, err) + + otherPeer := test.RandPeerIDFatal(t) + + ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234"), time.Hour) + ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour) + + ctx := context.Background() + mas, err := s.addrsForDial(ctx, otherPeer) + require.NoError(t, err) + + require.Equal(t, 1, len(mas)) +} + func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { priv, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err)