mirror of https://github.com/libp2p/go-libp2p.git
Marten Seemann
3 years ago
8 changed files with 741 additions and 0 deletions
@ -0,0 +1,266 @@ |
|||
//go:build !windows
|
|||
// +build !windows
|
|||
|
|||
package tcp |
|||
|
|||
import ( |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/marten-seemann/tcp" |
|||
"github.com/mikioh/tcpinfo" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
"github.com/prometheus/client_golang/prometheus" |
|||
) |
|||
|
|||
var ( |
|||
newConns *prometheus.CounterVec |
|||
closedConns *prometheus.CounterVec |
|||
segsSentDesc *prometheus.Desc |
|||
segsRcvdDesc *prometheus.Desc |
|||
bytesSentDesc *prometheus.Desc |
|||
bytesRcvdDesc *prometheus.Desc |
|||
) |
|||
|
|||
const collectFrequency = 10 * time.Second |
|||
|
|||
var collector *aggregatingCollector |
|||
|
|||
func init() { |
|||
segsSentDesc = prometheus.NewDesc("tcp_sent_segments_total", "TCP segments sent", nil, nil) |
|||
segsRcvdDesc = prometheus.NewDesc("tcp_rcvd_segments_total", "TCP segments received", nil, nil) |
|||
bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) |
|||
bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) |
|||
|
|||
collector = newAggregatingCollector() |
|||
prometheus.MustRegister(collector) |
|||
|
|||
const direction = "direction" |
|||
|
|||
newConns = prometheus.NewCounterVec( |
|||
prometheus.CounterOpts{ |
|||
Name: "tcp_connections_new_total", |
|||
Help: "TCP new connections", |
|||
}, |
|||
[]string{direction}, |
|||
) |
|||
prometheus.MustRegister(newConns) |
|||
closedConns = prometheus.NewCounterVec( |
|||
prometheus.CounterOpts{ |
|||
Name: "tcp_connections_closed_total", |
|||
Help: "TCP connections closed", |
|||
}, |
|||
[]string{direction}, |
|||
) |
|||
prometheus.MustRegister(closedConns) |
|||
} |
|||
|
|||
type aggregatingCollector struct { |
|||
cronOnce sync.Once |
|||
|
|||
mutex sync.Mutex |
|||
highestID uint64 |
|||
conns map[uint64] /* id */ *tracingConn |
|||
rtts prometheus.Histogram |
|||
connDurations prometheus.Histogram |
|||
segsSent, segsRcvd uint64 |
|||
bytesSent, bytesRcvd uint64 |
|||
} |
|||
|
|||
var _ prometheus.Collector = &aggregatingCollector{} |
|||
|
|||
func newAggregatingCollector() *aggregatingCollector { |
|||
c := &aggregatingCollector{ |
|||
conns: make(map[uint64]*tracingConn), |
|||
rtts: prometheus.NewHistogram(prometheus.HistogramOpts{ |
|||
Name: "tcp_rtt", |
|||
Help: "TCP round trip time", |
|||
Buckets: prometheus.ExponentialBuckets(0.001, 1.25, 40), // 1ms to ~6000ms
|
|||
}), |
|||
connDurations: prometheus.NewHistogram(prometheus.HistogramOpts{ |
|||
Name: "tcp_connection_duration", |
|||
Help: "TCP Connection Duration", |
|||
Buckets: prometheus.ExponentialBuckets(1, 1.5, 40), // 1s to ~12 weeks
|
|||
}), |
|||
} |
|||
return c |
|||
} |
|||
|
|||
func (c *aggregatingCollector) AddConn(t *tracingConn) uint64 { |
|||
c.mutex.Lock() |
|||
defer c.mutex.Unlock() |
|||
c.highestID++ |
|||
c.conns[c.highestID] = t |
|||
return c.highestID |
|||
} |
|||
|
|||
func (c *aggregatingCollector) removeConn(id uint64) { |
|||
delete(c.conns, id) |
|||
} |
|||
|
|||
func (c *aggregatingCollector) Describe(descs chan<- *prometheus.Desc) { |
|||
descs <- c.rtts.Desc() |
|||
descs <- c.connDurations.Desc() |
|||
if hasSegmentCounter { |
|||
descs <- segsSentDesc |
|||
descs <- segsRcvdDesc |
|||
} |
|||
if hasByteCounter { |
|||
descs <- bytesSentDesc |
|||
descs <- bytesRcvdDesc |
|||
} |
|||
} |
|||
|
|||
func (c *aggregatingCollector) cron() { |
|||
ticker := time.NewTicker(collectFrequency) |
|||
defer ticker.Stop() |
|||
|
|||
for now := range ticker.C { |
|||
c.gatherMetrics(now) |
|||
} |
|||
} |
|||
|
|||
func (c *aggregatingCollector) gatherMetrics(now time.Time) { |
|||
c.mutex.Lock() |
|||
defer c.mutex.Unlock() |
|||
|
|||
c.segsSent = 0 |
|||
c.segsRcvd = 0 |
|||
c.bytesSent = 0 |
|||
c.bytesRcvd = 0 |
|||
for _, conn := range c.conns { |
|||
info, err := conn.getTCPInfo() |
|||
if err != nil { |
|||
if strings.Contains(err.Error(), "use of closed network connection") { |
|||
continue |
|||
} |
|||
log.Errorf("Failed to get TCP info: %s", err) |
|||
continue |
|||
} |
|||
if hasSegmentCounter { |
|||
c.segsSent += getSegmentsSent(info) |
|||
c.segsRcvd += getSegmentsRcvd(info) |
|||
} |
|||
if hasByteCounter { |
|||
c.bytesSent += getBytesSent(info) |
|||
c.bytesRcvd += getBytesRcvd(info) |
|||
} |
|||
c.rtts.Observe(info.RTT.Seconds()) |
|||
c.connDurations.Observe(now.Sub(conn.startTime).Seconds()) |
|||
} |
|||
} |
|||
|
|||
func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { |
|||
// Start collecting the metrics collection the first time Collect is called.
|
|||
c.cronOnce.Do(func() { |
|||
c.gatherMetrics(time.Now()) |
|||
go c.cron() |
|||
}) |
|||
|
|||
c.mutex.Lock() |
|||
defer c.mutex.Unlock() |
|||
|
|||
metrics <- c.rtts |
|||
metrics <- c.connDurations |
|||
if hasSegmentCounter { |
|||
segsSentMetric, err := prometheus.NewConstMetric(segsSentDesc, prometheus.CounterValue, float64(c.segsSent)) |
|||
if err != nil { |
|||
log.Errorf("creating tcp_sent_segments_total metric failed: %v", err) |
|||
return |
|||
} |
|||
segsRcvdMetric, err := prometheus.NewConstMetric(segsRcvdDesc, prometheus.CounterValue, float64(c.segsRcvd)) |
|||
if err != nil { |
|||
log.Errorf("creating tcp_rcvd_segments_total metric failed: %v", err) |
|||
return |
|||
} |
|||
metrics <- segsSentMetric |
|||
metrics <- segsRcvdMetric |
|||
} |
|||
if hasByteCounter { |
|||
bytesSentMetric, err := prometheus.NewConstMetric(bytesSentDesc, prometheus.CounterValue, float64(c.bytesSent)) |
|||
if err != nil { |
|||
log.Errorf("creating tcp_sent_bytes metric failed: %v", err) |
|||
return |
|||
} |
|||
bytesRcvdMetric, err := prometheus.NewConstMetric(bytesRcvdDesc, prometheus.CounterValue, float64(c.bytesRcvd)) |
|||
if err != nil { |
|||
log.Errorf("creating tcp_rcvd_bytes metric failed: %v", err) |
|||
return |
|||
} |
|||
metrics <- bytesSentMetric |
|||
metrics <- bytesRcvdMetric |
|||
} |
|||
} |
|||
|
|||
func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { |
|||
c.mutex.Lock() |
|||
collector.removeConn(conn.id) |
|||
c.mutex.Unlock() |
|||
closedConns.WithLabelValues(direction).Inc() |
|||
} |
|||
|
|||
type tracingConn struct { |
|||
id uint64 |
|||
|
|||
startTime time.Time |
|||
isClient bool |
|||
|
|||
manet.Conn |
|||
tcpConn *tcp.Conn |
|||
} |
|||
|
|||
func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { |
|||
conn, err := tcp.NewConn(c) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tc := &tracingConn{ |
|||
startTime: time.Now(), |
|||
isClient: isClient, |
|||
Conn: c, |
|||
tcpConn: conn, |
|||
} |
|||
tc.id = collector.AddConn(tc) |
|||
newConns.WithLabelValues(tc.getDirection()).Inc() |
|||
return tc, nil |
|||
} |
|||
|
|||
func (c *tracingConn) getDirection() string { |
|||
if c.isClient { |
|||
return "outgoing" |
|||
} |
|||
return "incoming" |
|||
} |
|||
|
|||
func (c *tracingConn) Close() error { |
|||
collector.ClosedConn(c, c.getDirection()) |
|||
return c.Conn.Close() |
|||
} |
|||
|
|||
func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { |
|||
var o tcpinfo.Info |
|||
var b [256]byte |
|||
i, err := c.tcpConn.Option(o.Level(), o.Name(), b[:]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
info := i.(*tcpinfo.Info) |
|||
return info, nil |
|||
} |
|||
|
|||
type tracingListener struct { |
|||
manet.Listener |
|||
} |
|||
|
|||
func newTracingListener(l manet.Listener) *tracingListener { |
|||
return &tracingListener{Listener: l} |
|||
} |
|||
|
|||
func (l *tracingListener) Accept() (manet.Conn, error) { |
|||
conn, err := l.Listener.Accept() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newTracingConn(conn, false) |
|||
} |
@ -0,0 +1,16 @@ |
|||
//go:build darwin
|
|||
// +build darwin
|
|||
|
|||
package tcp |
|||
|
|||
import "github.com/mikioh/tcpinfo" |
|||
|
|||
const ( |
|||
hasSegmentCounter = true |
|||
hasByteCounter = true |
|||
) |
|||
|
|||
func getSegmentsSent(info *tcpinfo.Info) uint64 { return info.Sys.SegsSent } |
|||
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return info.Sys.SegsReceived } |
|||
func getBytesSent(info *tcpinfo.Info) uint64 { return info.Sys.BytesSent } |
|||
func getBytesRcvd(info *tcpinfo.Info) uint64 { return info.Sys.BytesReceived } |
@ -0,0 +1,16 @@ |
|||
//go:build !linux && !darwin && !windows
|
|||
// +build !linux,!darwin,!windows
|
|||
|
|||
package tcp |
|||
|
|||
import "github.com/mikioh/tcpinfo" |
|||
|
|||
const ( |
|||
hasSegmentCounter = false |
|||
hasByteCounter = false |
|||
) |
|||
|
|||
func getSegmentsSent(info *tcpinfo.Info) uint64 { return 0 } |
|||
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return 0 } |
|||
func getBytesSent(info *tcpinfo.Info) uint64 { return 0 } |
|||
func getBytesRcvd(info *tcpinfo.Info) uint64 { return 0 } |
@ -0,0 +1,16 @@ |
|||
//go:build linux
|
|||
// +build linux
|
|||
|
|||
package tcp |
|||
|
|||
import "github.com/mikioh/tcpinfo" |
|||
|
|||
const ( |
|||
hasSegmentCounter = true |
|||
hasByteCounter = false |
|||
) |
|||
|
|||
func getSegmentsSent(info *tcpinfo.Info) uint64 { return uint64(info.Sys.SegsOut) } |
|||
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return uint64(info.Sys.SegsIn) } |
|||
func getBytesSent(info *tcpinfo.Info) uint64 { return 0 } |
|||
func getBytesRcvd(info *tcpinfo.Info) uint64 { return 0 } |
@ -0,0 +1,9 @@ |
|||
//go:build windows
|
|||
// +build windows
|
|||
|
|||
package tcp |
|||
|
|||
import manet "github.com/multiformats/go-multiaddr/net" |
|||
|
|||
func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } |
|||
func newTracingListener(l manet.Listener) manet.Listener { return l } |
@ -0,0 +1,35 @@ |
|||
package tcp |
|||
|
|||
import ( |
|||
"os" |
|||
"strings" |
|||
|
|||
"github.com/libp2p/go-reuseport" |
|||
) |
|||
|
|||
// envReuseport is the env variable name used to turn off reuse port.
|
|||
// It default to true.
|
|||
const envReuseport = "LIBP2P_TCP_REUSEPORT" |
|||
|
|||
// envReuseportVal stores the value of envReuseport. defaults to true.
|
|||
var envReuseportVal = true |
|||
|
|||
func init() { |
|||
v := strings.ToLower(os.Getenv(envReuseport)) |
|||
if v == "false" || v == "f" || v == "0" { |
|||
envReuseportVal = false |
|||
log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v) |
|||
} |
|||
} |
|||
|
|||
// ReuseportIsAvailable returns whether reuseport is available to be used. This
|
|||
// is here because we want to be able to turn reuseport on and off selectively.
|
|||
// For now we use an ENV variable, as this handles our pressing need:
|
|||
//
|
|||
// LIBP2P_TCP_REUSEPORT=false ipfs daemon
|
|||
//
|
|||
// If this becomes a sought after feature, we could add this to the config.
|
|||
// In the end, reuseport is a stop-gap.
|
|||
func ReuseportIsAvailable() bool { |
|||
return envReuseportVal && reuseport.Available() |
|||
} |
@ -0,0 +1,239 @@ |
|||
package tcp |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"net" |
|||
"os" |
|||
"runtime" |
|||
"syscall" |
|||
"time" |
|||
|
|||
"github.com/libp2p/go-libp2p-core/network" |
|||
"github.com/libp2p/go-libp2p-core/peer" |
|||
"github.com/libp2p/go-libp2p-core/transport" |
|||
|
|||
rtpt "github.com/libp2p/go-reuseport-transport" |
|||
|
|||
logging "github.com/ipfs/go-log/v2" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
mafmt "github.com/multiformats/go-multiaddr-fmt" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
) |
|||
|
|||
const defaultConnectTimeout = 5 * time.Second |
|||
|
|||
var log = logging.Logger("tcp-tpt") |
|||
|
|||
const keepAlivePeriod = 30 * time.Second |
|||
|
|||
type canKeepAlive interface { |
|||
SetKeepAlive(bool) error |
|||
SetKeepAlivePeriod(time.Duration) error |
|||
} |
|||
|
|||
var _ canKeepAlive = &net.TCPConn{} |
|||
|
|||
func tryKeepAlive(conn net.Conn, keepAlive bool) { |
|||
keepAliveConn, ok := conn.(canKeepAlive) |
|||
if !ok { |
|||
log.Errorf("Can't set TCP keepalives.") |
|||
return |
|||
} |
|||
if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil { |
|||
// Sometimes we seem to get "invalid argument" results from this function on Darwin.
|
|||
// This might be due to a closed connection, but I can't reproduce that on Linux.
|
|||
//
|
|||
// But there's nothing we can do about invalid arguments, so we'll drop this to a
|
|||
// debug.
|
|||
if errors.Is(err, os.ErrInvalid) || errors.Is(err, syscall.EINVAL) { |
|||
log.Debugw("failed to enable TCP keepalive", "error", err) |
|||
} else { |
|||
log.Errorw("failed to enable TCP keepalive", "error", err) |
|||
} |
|||
return |
|||
} |
|||
|
|||
if runtime.GOOS != "openbsd" { |
|||
if err := keepAliveConn.SetKeepAlivePeriod(keepAlivePeriod); err != nil { |
|||
log.Errorw("failed set keepalive period", "error", err) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// try to set linger on the connection, if possible.
|
|||
func tryLinger(conn net.Conn, sec int) { |
|||
type canLinger interface { |
|||
SetLinger(int) error |
|||
} |
|||
|
|||
if lingerConn, ok := conn.(canLinger); ok { |
|||
_ = lingerConn.SetLinger(sec) |
|||
} |
|||
} |
|||
|
|||
type tcpListener struct { |
|||
manet.Listener |
|||
sec int |
|||
} |
|||
|
|||
func (ll *tcpListener) Accept() (manet.Conn, error) { |
|||
c, err := ll.Listener.Accept() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tryLinger(c, ll.sec) |
|||
tryKeepAlive(c, true) |
|||
// We're not calling OpenConnection in the resource manager here,
|
|||
// since the manet.Conn doesn't allow us to save the scope.
|
|||
// It's the caller's (usually the go-libp2p-transport-upgrader) responsibility
|
|||
// to call the resource manager.
|
|||
return c, nil |
|||
} |
|||
|
|||
type Option func(*TcpTransport) error |
|||
|
|||
func DisableReuseport() Option { |
|||
return func(tr *TcpTransport) error { |
|||
tr.disableReuseport = true |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
func WithConnectionTimeout(d time.Duration) Option { |
|||
return func(tr *TcpTransport) error { |
|||
tr.connectTimeout = d |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// TcpTransport is the TCP transport.
|
|||
type TcpTransport struct { |
|||
// Connection upgrader for upgrading insecure stream connections to
|
|||
// secure multiplex connections.
|
|||
Upgrader transport.Upgrader |
|||
|
|||
// Explicitly disable reuseport.
|
|||
disableReuseport bool |
|||
|
|||
// TCP connect timeout
|
|||
connectTimeout time.Duration |
|||
|
|||
rcmgr network.ResourceManager |
|||
|
|||
reuse rtpt.Transport |
|||
} |
|||
|
|||
var _ transport.Transport = &TcpTransport{} |
|||
|
|||
// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
|
|||
// created. It represents an entire TCP stack (though it might not necessarily be).
|
|||
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { |
|||
if rcmgr == nil { |
|||
rcmgr = network.NullResourceManager |
|||
} |
|||
tr := &TcpTransport{ |
|||
Upgrader: upgrader, |
|||
connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option
|
|||
rcmgr: rcmgr, |
|||
} |
|||
for _, o := range opts { |
|||
if err := o(tr); err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
return tr, nil |
|||
} |
|||
|
|||
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP)) |
|||
|
|||
// CanDial returns true if this transport believes it can dial the given
|
|||
// multiaddr.
|
|||
func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool { |
|||
return dialMatcher.Matches(addr) |
|||
} |
|||
|
|||
func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { |
|||
// Apply the deadline iff applicable
|
|||
if t.connectTimeout > 0 { |
|||
var cancel context.CancelFunc |
|||
ctx, cancel = context.WithTimeout(ctx, t.connectTimeout) |
|||
defer cancel() |
|||
} |
|||
|
|||
if t.UseReuseport() { |
|||
return t.reuse.DialContext(ctx, raddr) |
|||
} |
|||
var d manet.Dialer |
|||
return d.DialContext(ctx, raddr) |
|||
} |
|||
|
|||
// Dial dials the peer at the remote address.
|
|||
func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { |
|||
connScope, err := t.rcmgr.OpenConnection(network.DirOutbound, true) |
|||
if err != nil { |
|||
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) |
|||
return nil, err |
|||
} |
|||
if err := connScope.SetPeer(p); err != nil { |
|||
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) |
|||
connScope.Done() |
|||
return nil, err |
|||
} |
|||
conn, err := t.maDial(ctx, raddr) |
|||
if err != nil { |
|||
connScope.Done() |
|||
return nil, err |
|||
} |
|||
// Set linger to 0 so we never get stuck in the TIME-WAIT state. When
|
|||
// linger is 0, connections are _reset_ instead of closed with a FIN.
|
|||
// This means we can immediately reuse the 5-tuple and reconnect.
|
|||
tryLinger(conn, 0) |
|||
tryKeepAlive(conn, true) |
|||
c, err := newTracingConn(conn, true) |
|||
if err != nil { |
|||
connScope.Done() |
|||
return nil, err |
|||
} |
|||
direction := network.DirOutbound |
|||
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { |
|||
direction = network.DirInbound |
|||
} |
|||
return t.Upgrader.Upgrade(ctx, t, c, direction, p, connScope) |
|||
} |
|||
|
|||
// UseReuseport returns true if reuseport is enabled and available.
|
|||
func (t *TcpTransport) UseReuseport() bool { |
|||
return !t.disableReuseport && ReuseportIsAvailable() |
|||
} |
|||
|
|||
func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { |
|||
if t.UseReuseport() { |
|||
return t.reuse.Listen(laddr) |
|||
} |
|||
return manet.Listen(laddr) |
|||
} |
|||
|
|||
// Listen listens on the given multiaddr.
|
|||
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { |
|||
list, err := t.maListen(laddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
list = newTracingListener(&tcpListener{list, 0}) |
|||
return t.Upgrader.UpgradeListener(t, list), nil |
|||
} |
|||
|
|||
// Protocols returns the list of terminal protocols this transport can dial.
|
|||
func (t *TcpTransport) Protocols() []int { |
|||
return []int{ma.P_TCP} |
|||
} |
|||
|
|||
// Proxy always returns false for the TCP transport.
|
|||
func (t *TcpTransport) Proxy() bool { |
|||
return false |
|||
} |
|||
|
|||
func (t *TcpTransport) String() string { |
|||
return "TCP" |
|||
} |
@ -0,0 +1,144 @@ |
|||
package tcp |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"testing" |
|||
|
|||
"github.com/libp2p/go-libp2p-core/network" |
|||
|
|||
"github.com/libp2p/go-libp2p-core/crypto" |
|||
"github.com/libp2p/go-libp2p-core/peer" |
|||
"github.com/libp2p/go-libp2p-core/sec" |
|||
"github.com/libp2p/go-libp2p-core/sec/insecure" |
|||
"github.com/libp2p/go-libp2p-core/transport" |
|||
|
|||
csms "github.com/libp2p/go-conn-security-multistream" |
|||
mplex "github.com/libp2p/go-libp2p-mplex" |
|||
mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" |
|||
ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" |
|||
tptu "github.com/libp2p/go-libp2p-transport-upgrader" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
|
|||
"github.com/golang/mock/gomock" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func TestTcpTransport(t *testing.T) { |
|||
for i := 0; i < 2; i++ { |
|||
peerA, ia := makeInsecureMuxer(t) |
|||
_, ib := makeInsecureMuxer(t) |
|||
|
|||
ua, err := tptu.New(ia, new(mplex.Transport)) |
|||
require.NoError(t, err) |
|||
ta, err := NewTCPTransport(ua, nil) |
|||
require.NoError(t, err) |
|||
ub, err := tptu.New(ib, new(mplex.Transport)) |
|||
require.NoError(t, err) |
|||
tb, err := NewTCPTransport(ub, nil) |
|||
require.NoError(t, err) |
|||
|
|||
zero := "/ip4/127.0.0.1/tcp/0" |
|||
ttransport.SubtestTransport(t, ta, tb, zero, peerA) |
|||
|
|||
envReuseportVal = false |
|||
} |
|||
envReuseportVal = true |
|||
} |
|||
|
|||
func TestResourceManager(t *testing.T) { |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
|
|||
peerA, ia := makeInsecureMuxer(t) |
|||
_, ib := makeInsecureMuxer(t) |
|||
|
|||
ua, err := tptu.New(ia, new(mplex.Transport)) |
|||
require.NoError(t, err) |
|||
ta, err := NewTCPTransport(ua, nil) |
|||
require.NoError(t, err) |
|||
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
ub, err := tptu.New(ib, new(mplex.Transport)) |
|||
require.NoError(t, err) |
|||
rcmgr := mocknetwork.NewMockResourceManager(ctrl) |
|||
tb, err := NewTCPTransport(ub, rcmgr) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("success", func(t *testing.T) { |
|||
scope := mocknetwork.NewMockConnManagementScope(ctrl) |
|||
rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(scope, nil) |
|||
scope.EXPECT().SetPeer(peerA) |
|||
scope.EXPECT().PeerScope().Return(network.NullScope).AnyTimes() // called by the upgrader
|
|||
conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA) |
|||
require.NoError(t, err) |
|||
scope.EXPECT().Done() |
|||
defer conn.Close() |
|||
}) |
|||
|
|||
t.Run("connection denied", func(t *testing.T) { |
|||
rerr := errors.New("nope") |
|||
rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(nil, rerr) |
|||
_, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) |
|||
require.ErrorIs(t, err, rerr) |
|||
}) |
|||
|
|||
t.Run("peer denied", func(t *testing.T) { |
|||
scope := mocknetwork.NewMockConnManagementScope(ctrl) |
|||
rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(scope, nil) |
|||
rerr := errors.New("nope") |
|||
scope.EXPECT().SetPeer(peerA).Return(rerr) |
|||
scope.EXPECT().Done() |
|||
_, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) |
|||
require.ErrorIs(t, err, rerr) |
|||
}) |
|||
} |
|||
|
|||
func TestTcpTransportCantDialDNS(t *testing.T) { |
|||
for i := 0; i < 2; i++ { |
|||
dnsa, err := ma.NewMultiaddr("/dns4/example.com/tcp/1234") |
|||
require.NoError(t, err) |
|||
|
|||
var u transport.Upgrader |
|||
tpt, err := NewTCPTransport(u, nil) |
|||
require.NoError(t, err) |
|||
|
|||
if tpt.CanDial(dnsa) { |
|||
t.Fatal("shouldn't be able to dial dns") |
|||
} |
|||
|
|||
envReuseportVal = false |
|||
} |
|||
envReuseportVal = true |
|||
} |
|||
|
|||
func TestTcpTransportCantListenUtp(t *testing.T) { |
|||
for i := 0; i < 2; i++ { |
|||
utpa, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/utp") |
|||
require.NoError(t, err) |
|||
|
|||
var u transport.Upgrader |
|||
tpt, err := NewTCPTransport(u, nil) |
|||
require.NoError(t, err) |
|||
|
|||
_, err = tpt.Listen(utpa) |
|||
require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport") |
|||
|
|||
envReuseportVal = false |
|||
} |
|||
envReuseportVal = true |
|||
} |
|||
|
|||
func makeInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { |
|||
t.Helper() |
|||
priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256) |
|||
require.NoError(t, err) |
|||
id, err := peer.IDFromPrivateKey(priv) |
|||
require.NoError(t, err) |
|||
var secMuxer csms.SSMuxer |
|||
secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) |
|||
return id, &secMuxer |
|||
} |
Loading…
Reference in new issue