Browse Source

swarm: wait for transient connections to upgrade for NewStream (#2542)

pull/2607/head
Sukun 1 year ago
committed by GitHub
parent
commit
17a464978a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 111
      p2p/net/swarm/swarm.go
  2. 88
      p2p/test/basichost/basic_host_test.go
  3. 92
      p2p/test/swarm/swarm_test.go

111
p2p/net/swarm/swarm.go

@ -17,6 +17,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/core/transport"
"golang.org/x/exp/slices"
logging "github.com/ipfs/go-log/v2" logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
@ -172,6 +173,11 @@ type Swarm struct {
m map[network.Notifiee]struct{} m map[network.Notifiee]struct{}
} }
directConnNotifs struct {
sync.Mutex
m map[peer.ID][]chan struct{}
}
transports struct { transports struct {
sync.RWMutex sync.RWMutex
m map[int]transport.Transport m map[int]transport.Transport
@ -231,6 +237,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
s.listeners.m = make(map[transport.Listener]struct{}) s.listeners.m = make(map[transport.Listener]struct{})
s.transports.m = make(map[int]transport.Transport) s.transports.m = make(map[int]transport.Transport)
s.notifs.m = make(map[network.Notifiee]struct{}) s.notifs.m = make(map[network.Notifiee]struct{})
s.directConnNotifs.m = make(map[peer.ID][]chan struct{})
for _, opt := range opts { for _, opt := range opts {
if err := opt(s); err != nil { if err := opt(s); err != nil {
@ -390,6 +397,19 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
c.notifyLk.Lock() c.notifyLk.Lock()
s.conns.Unlock() s.conns.Unlock()
// Notify goroutines waiting for a direct connection
if !c.Stat().Transient {
// Go routines interested in waiting for direct connection first acquire this lock
// and then acquire s.conns.RLock. Do not acquire this lock before conns.Unlock to
// prevent deadlock.
s.directConnNotifs.Lock()
for _, ch := range s.directConnNotifs.m[p] {
close(ch)
}
delete(s.directConnNotifs.m, p)
s.directConnNotifs.Unlock()
}
// Emit event after releasing `s.conns` lock so that a consumer can still // Emit event after releasing `s.conns` lock so that a consumer can still
// use swarm methods that need the `s.conns` lock. // use swarm methods that need the `s.conns` lock.
if isFirstConnection { if isFirstConnection {
@ -436,46 +456,103 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
// Algorithm: // Algorithm:
// 1. Find the best connection, otherwise, dial. // 1. Find the best connection, otherwise, dial.
// 2. Try opening a stream. // 2. If the best connection is transient, wait for a direct conn via conn
// 3. If the underlying connection is, in fact, closed, close the outer // reversal or hole punching.
// 3. Try opening a stream.
// 4. If the underlying connection is, in fact, closed, close the outer
// connection and try again. We do this in case we have a closed // connection and try again. We do this in case we have a closed
// connection but don't notice it until we actually try to open a // connection but don't notice it until we actually try to open a
// stream. // stream.
// //
// Note: We only dial once.
//
// TODO: Try all connections even if we get an error opening a stream on // TODO: Try all connections even if we get an error opening a stream on
// a non-closed connection. // a non-closed connection.
dials := 0 numDials := 0
for { for {
// will prefer direct connections over relayed connections for opening streams c := s.bestConnToPeer(p)
c := s.bestAcceptableConnToPeer(ctx, p)
if c == nil { if c == nil {
if nodial, _ := network.GetNoDial(ctx); nodial { if nodial, _ := network.GetNoDial(ctx); !nodial {
numDials++
if numDials > DialAttempts {
return nil, errors.New("max dial attempts exceeded")
}
var err error
c, err = s.dialPeer(ctx, p)
if err != nil {
return nil, err
}
} else {
return nil, network.ErrNoConn return nil, network.ErrNoConn
} }
}
if dials >= DialAttempts { useTransient, _ := network.GetUseTransient(ctx)
return nil, errors.New("max dial attempts exceeded") if !useTransient && c.Stat().Transient {
}
dials++
var err error var err error
c, err = s.dialPeer(ctx, p) c, err = s.waitForDirectConn(ctx, p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
s, err := c.NewStream(ctx) str, err := c.NewStream(ctx)
if err != nil { if err != nil {
if c.conn.IsClosed() { if c.conn.IsClosed() {
continue continue
} }
return nil, err return nil, err
} }
return s, nil return str, nil
}
}
// waitForDirectConn waits for a direct connection established through hole punching or connection reversal.
func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) {
s.directConnNotifs.Lock()
c := s.bestConnToPeer(p)
if c == nil {
s.directConnNotifs.Unlock()
return nil, network.ErrNoConn
} else if !c.Stat().Transient {
s.directConnNotifs.Unlock()
return c, nil
}
// Wait for transient connection to upgrade to a direct connection either by
// connection reversal or hole punching.
ch := make(chan struct{})
s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch)
s.directConnNotifs.Unlock()
// apply the DialPeer timeout
ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx))
defer cancel()
// Wait for notification.
select {
case <-ctx.Done():
// Remove ourselves from the notification list
s.directConnNotifs.Lock()
defer s.directConnNotifs.Unlock()
s.directConnNotifs.m[p] = slices.DeleteFunc(
s.directConnNotifs.m[p],
func(c chan struct{}) bool { return c == ch },
)
if len(s.directConnNotifs.m[p]) == 0 {
delete(s.directConnNotifs.m, p)
}
return nil, ctx.Err()
case <-ch:
// We do not need to remove ourselves from the list here as the notifier
// clears the map entry
c := s.bestConnToPeer(p)
if c == nil {
return nil, network.ErrNoConn
}
if c.Stat().Transient {
return nil, network.ErrTransientConn
}
return c, nil
} }
} }

88
p2p/test/basichost/basic_host_test.go

@ -4,13 +4,16 @@ import (
"context" "context"
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -62,11 +65,92 @@ func TestNoStreamOverTransientConnection(t *testing.T) {
err = h1.Connect(context.Background(), h2Info) err = h1.Connect(context.Background(), h2Info)
require.NoError(t, err) require.NoError(t, err)
ctx := network.WithNoDial(context.Background(), "test") ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
_, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") _, err = h1.NewStream(ctx, h2.ID(), "/testprotocol")
require.ErrorIs(t, err, network.ErrTransientConn) require.Error(t, err)
_, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol") _, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol")
require.NoError(t, err) require.NoError(t, err)
} }
// TestNewStreamTransientConnection verifies that NewStream over a transient
// (relayed) connection blocks until a direct connection becomes available,
// and then succeeds without dialing when the peer connects back to us.
func TestNewStreamTransientConnection(t *testing.T) {
	h1, err := libp2p.New(
		libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
		libp2p.EnableRelay(),
	)
	require.NoError(t, err)

	h2, err := libp2p.New(
		libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
		libp2p.EnableRelay(),
	)
	require.NoError(t, err)

	relay1, err := libp2p.New()
	require.NoError(t, err)

	_, err = relay.New(relay1)
	require.NoError(t, err)

	relay1info := peer.AddrInfo{
		ID:    relay1.ID(),
		Addrs: relay1.Addrs(),
	}
	err = h1.Connect(context.Background(), relay1info)
	require.NoError(t, err)

	err = h2.Connect(context.Background(), relay1info)
	require.NoError(t, err)

	h2.SetStreamHandler("/testprotocol", func(s network.Stream) {
		fmt.Println("testprotocol")

		// End the example
		s.Close()
	})

	_, err = client.Reserve(context.Background(), h2, relay1info)
	require.NoError(t, err)

	relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())
	h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

	// NewStream should block transient connections till we have a direct connection
	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	s, err := h1.NewStream(ctx, h2.ID(), "/testprotocol")
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, s)

	// NewStream should return a stream if a direct connection is established
	// while waiting
	done := make(chan bool, 2)
	go func() {
		// Only assert.* here: require.* calls t.FailNow, which must not be
		// called from a goroutine other than the one running the test.
		// All variables are goroutine-local to avoid racing on the outer ones.
		h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL)
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		ctx = network.WithNoDial(ctx, "test")
		s, err := h1.NewStream(ctx, h2.ID(), "/testprotocol")
		if assert.NoError(t, err) && assert.NotNil(t, s) {
			defer s.Close()
			assert.Equal(t, network.DirInbound, s.Conn().Stat().Direction)
		}
		done <- true
	}()
	go func() {
		// connect h2 to h1 simulating connection reversal
		h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL)
		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancel()
		ctx = network.WithForceDirectDial(ctx, "test")
		err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()})
		assert.NoError(t, err)
		done <- true
	}()
	<-done
	<-done
}

92
p2p/test/swarm/swarm_test.go

@ -3,6 +3,7 @@ package swarm_test
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
@ -11,6 +12,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -68,3 +70,93 @@ func TestDialPeerTransientConnection(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Nil(t, conn) require.Nil(t, conn)
} }
// TestNewStreamTransientConnection checks swarm.NewStream behavior over a
// transient (relayed) connection: WithUseTransient succeeds immediately,
// plain NewStream blocks until the DialPeer timeout or until a direct
// connection appears, and WithNoDial never dials on its own.
func TestNewStreamTransientConnection(t *testing.T) {
	h1, err := libp2p.New(
		libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
		libp2p.EnableRelay(),
	)
	require.NoError(t, err)

	h2, err := libp2p.New(
		libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
		libp2p.EnableRelay(),
	)
	require.NoError(t, err)

	relay1, err := libp2p.New()
	require.NoError(t, err)

	_, err = relay.New(relay1)
	require.NoError(t, err)

	relay1info := peer.AddrInfo{
		ID:    relay1.ID(),
		Addrs: relay1.Addrs(),
	}
	err = h1.Connect(context.Background(), relay1info)
	require.NoError(t, err)

	err = h2.Connect(context.Background(), relay1info)
	require.NoError(t, err)

	_, err = client.Reserve(context.Background(), h2, relay1info)
	require.NoError(t, err)

	relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())
	h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

	// WithUseTransient should succeed
	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	ctx = network.WithUseTransient(ctx, "test")
	s, err := h1.Network().NewStream(ctx, h2.ID())
	require.NoError(t, err)
	require.NotNil(t, s)
	defer s.Close()

	// Without WithUseTransient should fail with context deadline exceeded
	ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	s, err = h1.Network().NewStream(ctx, h2.ID())
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, s)

	// Provide h2's direct address to h1.
	h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL)

	// network.NoDial should also fail
	ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	ctx = network.WithNoDial(ctx, "test")
	s, err = h1.Network().NewStream(ctx, h2.ID())
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, s)

	done := make(chan bool, 2)
	// NewStream should return a stream if an incoming direct connection is established
	go func() {
		// Only assert.* here: require.* calls t.FailNow, which must not be
		// called from a goroutine other than the one running the test.
		// All variables are goroutine-local to avoid racing on the outer ones.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		ctx = network.WithNoDial(ctx, "test")
		s, err := h1.Network().NewStream(ctx, h2.ID())
		if assert.NoError(t, err) && assert.NotNil(t, s) {
			defer s.Close()
			assert.Equal(t, network.DirInbound, s.Conn().Stat().Direction)
		}
		done <- true
	}()
	go func() {
		// connect h2 to h1 simulating connection reversal
		h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL)
		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancel()
		ctx = network.WithForceDirectDial(ctx, "test")
		err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()})
		assert.NoError(t, err)
		done <- true
	}()
	<-done
	<-done
}

Loading…
Cancel
Save