From b42ba0faf35fa29686e111c6de3c7bcb5d235212 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 14 May 2020 04:54:10 -0700 Subject: [PATCH] various identify fixes and nits (#922) * various identify fixes and nits Co-authored-by: Aarsh Shah --- p2p/protocol/identify/id.go | 76 +++++++++++------------ p2p/protocol/identify/id_test.go | 2 - p2p/protocol/identify/peer_loop.go | 80 +++++++++++-------------- p2p/protocol/identify/peer_loop_test.go | 44 ++++---------- 4 files changed, 84 insertions(+), 118 deletions(-) diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 822ee83dd..f9c37bf15 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -60,10 +60,8 @@ func init() { const transientTTL = 10 * time.Second type addPeerHandlerReq struct { - rp peer.ID - localConnAddr ma.Multiaddr - remoteConnAddr ma.Multiaddr - resp chan *peerHandler + rp peer.ID + resp chan *peerHandler } type rmPeerHandlerReq struct { @@ -194,9 +192,7 @@ func (ids *IDService) loop() { } if ids.Host.Network().Connectedness(rp) == network.Connected { - mes := &pb.Identify{} - ids.populateMessage(mes, rp, addReq.localConnAddr, addReq.remoteConnAddr) - ph = newPeerHandler(rp, ids, mes) + ph = newPeerHandler(rp, ids) ph.start() phs[rp] = ph addReq.resp <- ph @@ -378,7 +374,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) { defer func() { helpers.FullClose(s) if ph != nil { - ph.msgMu.RUnlock() + ph.snapshotMu.RUnlock() } }() @@ -386,8 +382,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) { phCh := make(chan *peerHandler, 1) select { - case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), c.LocalMultiaddr(), - c.RemoteMultiaddr(), phCh}: + case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), phCh}: case <-ids.ctx.Done(): return } @@ -398,9 +393,11 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) { return } - ph.msgMu.RLock() + ph.snapshotMu.RLock() + mes := &pb.Identify{} + ids.populateMessage(mes, c, ph.snapshot) w := ggio.NewDelimitedWriter(s) - w.WriteMsg(ph.idMsgSnapshot) + w.WriteMsg(mes) log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr()) } @@ -422,13 +419,29 @@ func (ids *IDService) handleIdentifyResponse(s network.Stream) { ids.consumeMessage(&mes, c) } -func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, remoteAddr ma.Multiaddr) { - // set protocols this node is currently handling - protos := ids.Host.Mux().Protocols() - mes.Protocols = make([]string, len(protos)) - for i, p := range protos { - mes.Protocols[i] = p +func (ids *IDService) getSnapshot() *identifySnapshot { + snapshot := new(identifySnapshot) + if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok { + snapshot.record = cab.GetPeerRecord(ids.Host.ID()) + if snapshot.record == nil { + log.Errorf("latest peer record does not exist. identify message incomplete!") + } } + snapshot.addrs = ids.Host.Addrs() + snapshot.protocols = ids.Host.Mux().Protocols() + return snapshot +} + +func (ids *IDService) populateMessage( + mes *pb.Identify, + conn network.Conn, + snapshot *identifySnapshot, +) { + remoteAddr := conn.RemoteMultiaddr() + localAddr := conn.LocalMultiaddr() + + // set protocols this node is currently handling + mes.Protocols = snapshot.protocols // observed address so other side is informed of their // "public" address, at least in relation to us. @@ -436,33 +449,22 @@ func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, r // populate unsigned addresses. // peers that do not yet support signed addresses will need this. - // set listen addrs, get our latest addrs from Host. - laddrs := ids.Host.Addrs() // Note: LocalMultiaddr is sometimes 0.0.0.0 viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr) - mes.ListenAddrs = make([][]byte, 0, len(laddrs)) - for _, addr := range laddrs { + mes.ListenAddrs = make([][]byte, 0, len(snapshot.addrs)) + for _, addr := range snapshot.addrs { if !viaLoopback && manet.IsIPLoopback(addr) { continue } mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes()) } - // populate signed record. - cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()) - if ok { - rec := cab.GetPeerRecord(ids.Host.ID()) - if rec == nil { - log.Errorf("latest peer record does not exist. identify message incomplete!") - } else { - recBytes, err := rec.Marshal() - if err != nil { - log.Errorf("error marshaling peer record: %v", err) - } else { - mes.SignedPeerRecord = recBytes - log.Debugf("%s sent peer record to %s", ids.Host.ID(), rp) - } - } + recBytes, err := snapshot.record.Marshal() + if err != nil { + log.Errorf("error marshaling peer record: %v", err) + } else { + mes.SignedPeerRecord = recBytes + log.Debugf("%s sent peer record to %s", ids.Host.ID(), conn.RemotePeer()) } // set our public key diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index d0006347a..e6092d6c0 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -490,8 +490,6 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) { lk.Unlock() } } - - close(done) }() <-done diff --git a/p2p/protocol/identify/peer_loop.go b/p2p/protocol/identify/peer_loop.go index 7fc4021d4..8e66b0ef6 100644 --- a/p2p/protocol/identify/peer_loop.go +++ b/p2p/protocol/identify/peer_loop.go @@ -11,14 +11,21 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-core/record" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" ggio "github.com/gogo/protobuf/io" + ma "github.com/multiformats/go-multiaddr" ) var errProtocolNotSupported = errors.New("protocol not supported") -var isTesting = false + +type identifySnapshot struct { + protocols []string + addrs []ma.Multiaddr + record *record.Envelope +} type peerHandler struct { ids *IDService @@ -29,29 +36,24 @@ type peerHandler struct { pid peer.ID - msgMu sync.RWMutex - idMsgSnapshot *pb.Identify + snapshotMu sync.RWMutex + snapshot *identifySnapshot - pushCh chan struct{} - deltaCh chan struct{} - evalTestCh chan func() // for testing + pushCh chan struct{} + deltaCh chan struct{} } -func newPeerHandler(pid peer.ID, ids *IDService, initState *pb.Identify) *peerHandler { +func newPeerHandler(pid peer.ID, ids *IDService) *peerHandler { ph := &peerHandler{ ids: ids, pid: pid, - idMsgSnapshot: initState, + snapshot: ids.getSnapshot(), pushCh: make(chan struct{}, 1), deltaCh: make(chan struct{}, 1), } - if isTesting { - ph.evalTestCh = make(chan func()) - } - return ph } @@ -87,9 +89,6 @@ func (ph *peerHandler) loop() { log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err) } - case fnc := <-ph.evalTestCh: - fnc() - case <-ph.ctx.Done(): return } @@ -97,11 +96,6 @@ func (ph *peerHandler) loop() { } func (ph *peerHandler) sendDelta() error { - mes := ph.mkDelta() - if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) { - return nil - } - // send a push if the peer does not support the Delta protocol. if !ph.peerSupportsProtos([]string{IDDelta}) { log.Debugw("will send push as peer does not support delta", "peer", ph.pid) @@ -111,10 +105,11 @@ func (ph *peerHandler) sendDelta() error { return nil } - ph.msgMu.Lock() - // update our identify snapshot for this peer by applying the delta to it - ph.applyDelta(mes) - ph.msgMu.Unlock() + // extract a delta message, updating the last state. + mes := ph.nextDelta() + if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) { + return nil + } ds, err := ph.openStream([]string{IDDelta}) if err != nil { @@ -139,11 +134,13 @@ func (ph *peerHandler) sendPush() error { conn := dp.Conn() mes := &pb.Identify{} - ph.ids.populateMessage(mes, ph.pid, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) - ph.msgMu.Lock() - ph.idMsgSnapshot = mes - ph.msgMu.Unlock() + snapshot := ph.ids.getSnapshot() + ph.snapshotMu.Lock() + ph.snapshot = snapshot + ph.snapshotMu.Unlock() + + ph.ids.populateMessage(mes, conn, snapshot) if err := ph.sendMessage(dp, mes); err != nil { return fmt.Errorf("failed to send push message: %w", err) @@ -151,21 +148,6 @@ func (ph *peerHandler) sendPush() error { return nil } -func (ph *peerHandler) applyDelta(mes *pb.Delta) { - for _, p1 := range mes.RmProtocols { - for j, p2 := range ph.idMsgSnapshot.Protocols { - if p2 == p1 { - ph.idMsgSnapshot.Protocols[j] = ph.idMsgSnapshot.Protocols[len(ph.idMsgSnapshot.Protocols)-1] - ph.idMsgSnapshot.Protocols = ph.idMsgSnapshot.Protocols[:len(ph.idMsgSnapshot.Protocols)-1] - } - } - } - - for _, p := range mes.AddedProtocols { - ph.idMsgSnapshot.Protocols = append(ph.idMsgSnapshot.Protocols, p) - } -} - func (ph *peerHandler) openStream(protos []string) (network.Stream, error) { // wait for the other peer to send us an Identify response on "all" connections we have with it // so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol @@ -217,10 +199,18 @@ func (ph *peerHandler) peerSupportsProtos(protos []string) bool { return true } -func (ph *peerHandler) mkDelta() *pb.Delta { - old := ph.idMsgSnapshot.GetProtocols() +func (ph *peerHandler) nextDelta() *pb.Delta { curr := ph.ids.Host.Mux().Protocols() + // Extract the old protocol list and replace the old snapshot with an + // updated one. + ph.snapshotMu.Lock() + snapshot := *ph.snapshot + old := snapshot.protocols + snapshot.protocols = curr + ph.snapshot = &snapshot + ph.snapshotMu.Unlock() + oldProtos := make(map[string]struct{}, len(old)) currProtos := make(map[string]struct{}, len(curr)) diff --git a/p2p/protocol/identify/peer_loop_test.go b/p2p/protocol/identify/peer_loop_test.go index 6eb3bd605..06286f978 100644 --- a/p2p/protocol/identify/peer_loop_test.go +++ b/p2p/protocol/identify/peer_loop_test.go @@ -9,72 +9,49 @@ import ( blhost "github.com/libp2p/go-libp2p-blankhost" swarmt "github.com/libp2p/go-libp2p-swarm/testing" - pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" "github.com/stretchr/testify/require" ) -func doeval(t *testing.T, ph *peerHandler, f func()) { - done := make(chan struct{}, 1) - ph.evalTestCh <- func() { - f() - done <- struct{}{} - } - <-done -} - func TestMakeApplyDelta(t *testing.T) { - isTesting = true ctx, cancel := context.WithCancel(context.Background()) defer cancel() h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) defer h1.Close() ids1 := NewIDService(h1) - ph := newPeerHandler(h1.ID(), ids1, &pb.Identify{}) + ph := newPeerHandler(h1.ID(), ids1) ph.start() defer ph.close() - m1 := ph.mkDelta() + m1 := ph.nextDelta() require.NotNil(t, m1) - // all the Id protocols must have been added - require.NotEmpty(t, m1.AddedProtocols) - doeval(t, ph, func() { - ph.applyDelta(m1) - }) + // We haven't changed anything since creating the peer handler + require.Empty(t, m1.AddedProtocols) h1.SetStreamHandler("p1", func(network.Stream) {}) - m2 := ph.mkDelta() + m2 := ph.nextDelta() require.Len(t, m2.AddedProtocols, 1) require.Contains(t, m2.AddedProtocols, "p1") require.Empty(t, m2.RmProtocols) - doeval(t, ph, func() { - ph.applyDelta(m2) - }) h1.SetStreamHandler("p2", func(network.Stream) {}) h1.SetStreamHandler("p3", func(stream network.Stream) {}) - m3 := ph.mkDelta() + m3 := ph.nextDelta() require.Len(t, m3.AddedProtocols, 2) require.Contains(t, m3.AddedProtocols, "p2") require.Contains(t, m3.AddedProtocols, "p3") require.Empty(t, m3.RmProtocols) - doeval(t, ph, func() { - ph.applyDelta(m3) - }) h1.RemoveStreamHandler("p3") - m4 := ph.mkDelta() + m4 := ph.nextDelta() require.Empty(t, m4.AddedProtocols) require.Len(t, m4.RmProtocols, 1) require.Contains(t, m4.RmProtocols, "p3") - doeval(t, ph, func() { - ph.applyDelta(m4) - }) h1.RemoveStreamHandler("p2") h1.RemoveStreamHandler("p1") - m5 := ph.mkDelta() + m5 := ph.nextDelta() require.Empty(t, m5.AddedProtocols) require.Len(t, m5.RmProtocols, 2) require.Contains(t, m5.RmProtocols, "p2") @@ -82,14 +59,13 @@ func TestMakeApplyDelta(t *testing.T) { } func TestHandlerClose(t *testing.T) { - isTesting = true ctx, cancel := context.WithCancel(context.Background()) defer cancel() h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) defer h1.Close() ids1 := NewIDService(h1) - ph := newPeerHandler(h1.ID(), ids1, nil) + ph := newPeerHandler(h1.ID(), ids1) ph.start() require.NoError(t, ph.close()) @@ -104,7 +80,7 @@ func TestPeerSupportsProto(t *testing.T) { ids1 := NewIDService(h1) rp := peer.ID("test") - ph := newPeerHandler(rp, ids1, nil) + ph := newPeerHandler(rp, ids1) require.NoError(t, h1.Peerstore().AddProtocols(rp, "test")) require.True(t, ph.peerSupportsProtos([]string{"test"})) require.False(t, ph.peerSupportsProtos([]string{"random"}))