diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 78795f17b..f51733b1b 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -159,9 +159,9 @@ loop: // Enqueue the peer's addresses relevant to this request in dq and // track dials to the addresses relevant to this request. - c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) - if c != nil || err != nil { - req.resch <- dialResponse{conn: c, err: err} + c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) + if c != nil { + req.resch <- dialResponse{conn: c} continue loop } @@ -373,7 +373,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { // all addrs have erred, dispatch dial error // but first do a last one check in case an acceptable connection has landed from // a simultaneous dial that started later and added new acceptable addrs - c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) + c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) if c != nil { pr.req.resch <- dialResponse{conn: c} } else { diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 634e921f5..57361e8d5 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -429,6 +429,8 @@ func (s *Swarm) StreamHandler() network.StreamHandler { // NewStream creates a new stream on any available connection to peer, dialing // if necessary. +// Use network.WithUseTransient to open a stream over a transient(relayed) +// connection. func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) { log.Debugf("[%s] opening stream to peer [%s]", s.local, p) @@ -447,10 +449,7 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error dials := 0 for { // will prefer direct connections over relayed connections for opening streams - c, err := s.bestAcceptableConnToPeer(ctx, p) - if err != nil { - return nil, err - } + c := s.bestAcceptableConnToPeer(ctx, p) if c == nil { if nodial, _ := network.GetNoDial(ctx); nodial { @@ -548,26 +547,17 @@ func (s *Swarm) bestConnToPeer(p peer.ID) *Conn { return best } -// - Returns the best "acceptable" connection, if available. -// - Returns nothing if no such connection exists, but if we should try dialing anyways. -// - Returns an error if no such connection exists, but we should not try dialing. -func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) (*Conn, error) { +// bestAcceptableConnToPeer returns the best acceptable connection, considering the passed in ctx. +// If network.WithForceDirectDial is used, it only returns a direct connections, ignoring +// any transient (relayed) connections to the peer. +func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn { conn := s.bestConnToPeer(p) - if conn == nil { - return nil, nil - } forceDirect, _ := network.GetForceDirectDial(ctx) if forceDirect && !isDirectConn(conn) { - return nil, nil - } - - useTransient, _ := network.GetUseTransient(ctx) - if useTransient || !conn.Stat().Transient { - return conn, nil + return nil } - - return nil, network.ErrTransientConn + return conn } func isDirectConn(c *Conn) bool { diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index dbdd4d3c0..761ab96c9 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -216,7 +216,8 @@ func (db *DialBackoff) cleanup() { } } -// DialPeer connects to a peer. +// DialPeer connects to a peer. Use network.WithForceDirectDial to force a +// direct connection. // // The idea is that the client of Swarm does not need to know what network // the connection will happen over. Swarm can use whichever it choses. @@ -246,11 +247,10 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { return nil, ErrDialToSelf } - // check if we already have an open (usable) connection first, or can't have a usable - // connection. - conn, err := s.bestAcceptableConnToPeer(ctx, p) - if conn != nil || err != nil { - return conn, err + // check if we already have an open (usable) connection. + conn := s.bestAcceptableConnToPeer(ctx, p) + if conn != nil { + return conn, nil } if s.gater != nil && !s.gater.InterceptPeerDial(p) { diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go new file mode 100644 index 000000000..2ddadb357 --- /dev/null +++ b/p2p/test/swarm/swarm_test.go @@ -0,0 +1,70 @@ +package swarm_test + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func TestDialPeerTransientConnection(t *testing.T) { + h1, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relay.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + err = h1.Connect(context.Background(), relay1info) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + + h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL) + + // swarm.DialPeer should connect over transient connections + conn1, err := h1.Network().DialPeer(context.Background(), h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn1) + + // Test that repeated calls return the same connection. + conn2, err := h1.Network().DialPeer(context.Background(), h2.ID()) + require.NoError(t, err) + require.NotNil(t, conn2) + + require.Equal(t, conn1, conn2) + + // swarm.DialPeer should fail if forceDirect is used + ctx := network.WithForceDirectDial(context.Background(), "test") + conn, err := h1.Network().DialPeer(ctx, h2.ID()) + require.Error(t, err) + require.Nil(t, conn) +}