mirror of https://github.com/libp2p/go-libp2p.git
Browse Source
* hole punching protocol * implement the new retry logic (opening a new stream per hole punch) * rename the holepunch.HolePunchService to holepunch.Service * remove stream handler when closing * remove misleading comment from protobuf, add link to spec * remove unnecessary netNotifiee.HolePunchService * simplify receiving from the timer channel * don't defer the closing of the hole punching stream * only accept hole punch streams from relayed peers * refactor error logging * fix incorrect counting of incoming hole punch attempts * rename HolePunch to DirectConnect * fix usage of the wait group to close the hole punch service * use a gc to the garbage collect the holepunch tracer * check directionality of connection when accepting the hole punch stream * add comment why we only call Connect once * move deduplication logic to the holepunch service Co-authored-by: aarshkshah1992 <aarshkshah1992@gmail.com>pull/1178/head
Marten Seemann
3 years ago
committed by
GitHub
10 changed files with 1547 additions and 9 deletions
@ -0,0 +1,404 @@ |
|||
package holepunch |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"fmt" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/libp2p/go-libp2p-core/host" |
|||
"github.com/libp2p/go-libp2p-core/network" |
|||
"github.com/libp2p/go-libp2p-core/peer" |
|||
"github.com/libp2p/go-libp2p-core/protocol" |
|||
pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" |
|||
"github.com/libp2p/go-libp2p/p2p/protocol/identify" |
|||
|
|||
logging "github.com/ipfs/go-log/v2" |
|||
"github.com/libp2p/go-msgio/protoio" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
) |
|||
|
|||
// Protocol is the libp2p protocol for Hole Punching.
|
|||
const Protocol protocol.ID = "/libp2p/dcutr" |
|||
|
|||
// StreamTimeout is the timeout for the hole punch protocol stream.
|
|||
var StreamTimeout = 1 * time.Minute |
|||
|
|||
// TODO Should we have options for these ?
|
|||
const ( |
|||
maxMsgSize = 4 * 1024 // 4K
|
|||
dialTimeout = 5 * time.Second |
|||
maxRetries = 3 |
|||
retryWait = 2 * time.Second |
|||
) |
|||
|
|||
var ( |
|||
log = logging.Logger("p2p-holepunch") |
|||
// ErrHolePunchActive is returned from DirectConnect when another hole punching attempt is currently running
|
|||
ErrHolePunchActive = errors.New("another hole punching attempt to this peer is active") |
|||
// ErrClosed is returned when the hole punching is closed
|
|||
ErrClosed = errors.New("hole punching service closing") |
|||
) |
|||
|
|||
// The Service is used to make direct connections with a peer via hole-punching.
|
|||
type Service struct { |
|||
ctx context.Context |
|||
ctxCancel context.CancelFunc |
|||
|
|||
ids *identify.IDService |
|||
host host.Host |
|||
|
|||
tracer *tracer |
|||
|
|||
closeMx sync.RWMutex |
|||
closed bool |
|||
refCount sync.WaitGroup |
|||
|
|||
// active hole punches for deduplicating
|
|||
activeMx sync.Mutex |
|||
active map[peer.ID]struct{} |
|||
} |
|||
|
|||
type Option func(*Service) error |
|||
|
|||
// NewService creates a new service that can be used for hole punching
|
|||
func NewService(h host.Host, ids *identify.IDService, opts ...Option) (*Service, error) { |
|||
if ids == nil { |
|||
return nil, errors.New("identify service can't be nil") |
|||
} |
|||
|
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
hs := &Service{ |
|||
ctx: ctx, |
|||
ctxCancel: cancel, |
|||
host: h, |
|||
ids: ids, |
|||
active: make(map[peer.ID]struct{}), |
|||
} |
|||
|
|||
for _, opt := range opts { |
|||
if err := opt(hs); err != nil { |
|||
cancel() |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
h.SetStreamHandler(Protocol, hs.handleNewStream) |
|||
h.Network().Notify((*netNotifiee)(hs)) |
|||
return hs, nil |
|||
} |
|||
|
|||
// Close closes the Hole Punch Service.
|
|||
func (hs *Service) Close() error { |
|||
hs.closeMx.Lock() |
|||
hs.closed = true |
|||
hs.closeMx.Unlock() |
|||
hs.tracer.Close() |
|||
hs.host.RemoveStreamHandler(Protocol) |
|||
hs.ctxCancel() |
|||
hs.refCount.Wait() |
|||
return nil |
|||
} |
|||
|
|||
// initiateHolePunch opens a new hole punching coordination stream,
|
|||
// exchanges the addresses and measures the RTT.
|
|||
func (hs *Service) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, time.Duration, error) { |
|||
hpCtx := network.WithUseTransient(hs.ctx, "hole-punch") |
|||
sCtx := network.WithNoDial(hpCtx, "hole-punch") |
|||
str, err := hs.host.NewStream(sCtx, rp, Protocol) |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("failed to open hole-punching stream with peer %s: %w", rp, err) |
|||
} |
|||
defer str.Close() |
|||
str.SetDeadline(time.Now().Add(StreamTimeout)) |
|||
|
|||
w := protoio.NewDelimitedWriter(str) |
|||
|
|||
// send a CONNECT and start RTT measurement.
|
|||
msg := &pb.HolePunch{ |
|||
Type: pb.HolePunch_CONNECT.Enum(), |
|||
ObsAddrs: addrsToBytes(hs.ids.OwnObservedAddrs()), |
|||
} |
|||
|
|||
start := time.Now() |
|||
if err := w.WriteMsg(msg); err != nil { |
|||
str.Reset() |
|||
return nil, 0, err |
|||
} |
|||
|
|||
// wait for a CONNECT message from the remote peer
|
|||
rd := protoio.NewDelimitedReader(str, maxMsgSize) |
|||
msg.Reset() |
|||
if err := rd.ReadMsg(msg); err != nil { |
|||
str.Reset() |
|||
return nil, 0, fmt.Errorf("failed to read CONNECT message from remote peer: %w", err) |
|||
} |
|||
rtt := time.Since(start) |
|||
|
|||
if t := msg.GetType(); t != pb.HolePunch_CONNECT { |
|||
str.Reset() |
|||
return nil, 0, fmt.Errorf("expect CONNECT message, got %s", t) |
|||
} |
|||
|
|||
addrs := addrsFromBytes(msg.ObsAddrs) |
|||
|
|||
msg.Reset() |
|||
msg.Type = pb.HolePunch_SYNC.Enum() |
|||
if err := w.WriteMsg(msg); err != nil { |
|||
str.Reset() |
|||
return nil, 0, fmt.Errorf("failed to send SYNC message for hole punching: %w", err) |
|||
} |
|||
return addrs, rtt, nil |
|||
} |
|||
|
|||
func (hs *Service) beginDirectConnect(p peer.ID) error { |
|||
hs.closeMx.RLock() |
|||
defer hs.closeMx.RUnlock() |
|||
if hs.closed { |
|||
return ErrClosed |
|||
} |
|||
|
|||
hs.activeMx.Lock() |
|||
defer hs.activeMx.Unlock() |
|||
if _, ok := hs.active[p]; ok { |
|||
return ErrHolePunchActive |
|||
} |
|||
|
|||
hs.active[p] = struct{}{} |
|||
return nil |
|||
} |
|||
|
|||
// DirectConnect attempts to make a direct connection with a remote peer.
|
|||
// It first attempts a direct dial (if we have a public address of that peer), and then
|
|||
// coordinates a hole punch over the given relay connection.
|
|||
func (hs *Service) DirectConnect(p peer.ID) error { |
|||
log.Debugw("got inbound proxy conn from peer", p) |
|||
if err := hs.beginDirectConnect(p); err != nil { |
|||
return err |
|||
} |
|||
|
|||
defer func() { |
|||
hs.activeMx.Lock() |
|||
delete(hs.active, p) |
|||
hs.activeMx.Unlock() |
|||
}() |
|||
|
|||
return hs.directConnect(p) |
|||
} |
|||
|
|||
func (hs *Service) directConnect(rp peer.ID) error { |
|||
// short-circuit check to see if we already have a direct connection
|
|||
for _, c := range hs.host.Network().ConnsToPeer(rp) { |
|||
if !isRelayAddress(c.RemoteMultiaddr()) { |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// short-circuit hole punching if a direct dial works.
|
|||
// attempt a direct connection ONLY if we have a public address for the remote peer
|
|||
for _, a := range hs.host.Peerstore().Addrs(rp) { |
|||
if manet.IsPublicAddr(a) && !isRelayAddress(a) { |
|||
forceDirectConnCtx := network.WithForceDirectDial(hs.ctx, "hole-punching") |
|||
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) |
|||
|
|||
tstart := time.Now() |
|||
// This dials *all* public addresses from the peerstore.
|
|||
err := hs.host.Connect(dialCtx, peer.AddrInfo{ID: rp}) |
|||
dt := time.Since(tstart) |
|||
cancel() |
|||
|
|||
if err != nil { |
|||
hs.tracer.DirectDialFailed(rp, dt, err) |
|||
break |
|||
} |
|||
hs.tracer.DirectDialSuccessful(rp, dt) |
|||
log.Debugf("direct connection to peer %s successful, no need for a hole punch", rp.Pretty()) |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// hole punch
|
|||
for i := 0; i < maxRetries; i++ { |
|||
addrs, rtt, err := hs.initiateHolePunch(rp) |
|||
if err != nil { |
|||
log.Debugw("hole punching failed to peer", rp, "error", err) |
|||
hs.tracer.ProtocolError(rp, err) |
|||
return err |
|||
} |
|||
synTime := rtt / 2 |
|||
log.Debugf("peer RTT is %s; starting hole punch in %s", rtt, synTime) |
|||
|
|||
// wait for sync to reach the other peer and then punch a hole for it in our NAT
|
|||
// by attempting a connect to it.
|
|||
timer := time.NewTimer(synTime) |
|||
select { |
|||
case start := <-timer.C: |
|||
pi := peer.AddrInfo{ |
|||
ID: rp, |
|||
Addrs: addrs, |
|||
} |
|||
hs.tracer.StartHolePunch(rp, addrs, rtt) |
|||
err := hs.holePunchConnect(pi) |
|||
dt := time.Since(start) |
|||
hs.tracer.EndHolePunch(rp, dt, err) |
|||
if err == nil { |
|||
log.Debugw("hole punching with", rp, "successful after", dt) |
|||
return nil |
|||
} |
|||
case <-hs.ctx.Done(): |
|||
timer.Stop() |
|||
return hs.ctx.Err() |
|||
} |
|||
} |
|||
return fmt.Errorf("all retries for hole punch with peer %s failed", rp) |
|||
} |
|||
|
|||
func (hs *Service) incomingHolePunch(s network.Stream) (rtt time.Duration, addrs []ma.Multiaddr, err error) { |
|||
// sanity check: a hole punch request should only come from peers behind a relay
|
|||
if !isRelayAddress(s.Conn().RemoteMultiaddr()) { |
|||
return 0, nil, fmt.Errorf("received hole punch stream: %s", s.Conn().RemoteMultiaddr()) |
|||
} |
|||
|
|||
log.Debugw("got hole punch request from peer", s.Conn().RemotePeer().Pretty()) |
|||
_ = s.SetDeadline(time.Now().Add(StreamTimeout)) |
|||
wr := protoio.NewDelimitedWriter(s) |
|||
rd := protoio.NewDelimitedReader(s, maxMsgSize) |
|||
|
|||
// Read Connect message
|
|||
msg := new(pb.HolePunch) |
|||
if err := rd.ReadMsg(msg); err != nil { |
|||
return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) |
|||
} |
|||
if t := msg.GetType(); t != pb.HolePunch_CONNECT { |
|||
return 0, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t) |
|||
} |
|||
obsDial := addrsFromBytes(msg.ObsAddrs) |
|||
|
|||
// Write CONNECT message
|
|||
msg.Reset() |
|||
msg.Type = pb.HolePunch_CONNECT.Enum() |
|||
msg.ObsAddrs = addrsToBytes(hs.ids.OwnObservedAddrs()) |
|||
tstart := time.Now() |
|||
if err := wr.WriteMsg(msg); err != nil { |
|||
return 0, nil, fmt.Errorf("failed to write CONNECT message to initator: %w", err) |
|||
} |
|||
|
|||
// Read SYNC message
|
|||
msg.Reset() |
|||
if err := rd.ReadMsg(msg); err != nil { |
|||
return 0, nil, fmt.Errorf("failed to read message from initator: %w", err) |
|||
} |
|||
if t := msg.GetType(); t != pb.HolePunch_SYNC { |
|||
return 0, nil, fmt.Errorf("expected SYNC message from initiator but got %d", t) |
|||
} |
|||
return time.Since(tstart), obsDial, nil |
|||
} |
|||
|
|||
func (hs *Service) handleNewStream(s network.Stream) { |
|||
// Check directionality of the underlying connection.
|
|||
// Peer A receives an inbound connection from peer B.
|
|||
// Peer A opens a new hole punch stream to peer B.
|
|||
// Peer B receives this stream, calling this function.
|
|||
// Peer B sees the underlying connection as an outbound connection.
|
|||
if s.Conn().Stat().Direction == network.DirInbound { |
|||
s.Reset() |
|||
return |
|||
} |
|||
rp := s.Conn().RemotePeer() |
|||
rtt, addrs, err := hs.incomingHolePunch(s) |
|||
if err != nil { |
|||
hs.tracer.ProtocolError(rp, err) |
|||
log.Debugw("error handling holepunching stream from", rp, "error", err) |
|||
s.Reset() |
|||
return |
|||
} |
|||
s.Close() |
|||
|
|||
// Hole punch now by forcing a connect
|
|||
pi := peer.AddrInfo{ |
|||
ID: rp, |
|||
Addrs: addrs, |
|||
} |
|||
hs.tracer.StartHolePunch(rp, addrs, rtt) |
|||
log.Debugw("starting hole punch with", rp) |
|||
start := time.Now() |
|||
err = hs.holePunchConnect(pi) |
|||
dt := time.Since(start) |
|||
hs.tracer.EndHolePunch(rp, dt, err) |
|||
if err != nil { |
|||
log.Debugw("hole punching peer", rp, "failed after", dt, "error", err) |
|||
} else { |
|||
log.Debugw("hole punching peer", rp, "succeeded after", dt) |
|||
} |
|||
} |
|||
|
|||
func (hs *Service) holePunchConnect(pi peer.AddrInfo) error { |
|||
holePunchCtx := network.WithSimultaneousConnect(hs.ctx, "hole-punching") |
|||
forceDirectConnCtx := network.WithForceDirectDial(holePunchCtx, "hole-punching") |
|||
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout) |
|||
defer cancel() |
|||
|
|||
hs.tracer.HolePunchAttempt(pi.ID) |
|||
err := hs.host.Connect(dialCtx, pi) |
|||
if err == nil { |
|||
log.Debugw("hole punch with peer", pi.ID, "successful") |
|||
} |
|||
log.Debugw("hole punch attempt with peer", pi.ID, "failed:", err) |
|||
return err |
|||
} |
|||
|
|||
func isRelayAddress(a ma.Multiaddr) bool { |
|||
_, err := a.ValueForProtocol(ma.P_CIRCUIT) |
|||
return err == nil |
|||
} |
|||
|
|||
func addrsToBytes(as []ma.Multiaddr) [][]byte { |
|||
bzs := make([][]byte, 0, len(as)) |
|||
for _, a := range as { |
|||
bzs = append(bzs, a.Bytes()) |
|||
} |
|||
return bzs |
|||
} |
|||
|
|||
func addrsFromBytes(bzs [][]byte) []ma.Multiaddr { |
|||
addrs := make([]ma.Multiaddr, 0, len(bzs)) |
|||
for _, bz := range bzs { |
|||
a, err := ma.NewMultiaddrBytes(bz) |
|||
if err == nil { |
|||
addrs = append(addrs, a) |
|||
} |
|||
} |
|||
return addrs |
|||
} |
|||
|
|||
type netNotifiee Service |
|||
|
|||
func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) { |
|||
hs := (*Service)(nn) |
|||
|
|||
// Hole punch if it's an inbound proxy connection.
|
|||
// If we already have a direct connection with the remote peer, this will be a no-op.
|
|||
if conn.Stat().Direction == network.DirInbound && isRelayAddress(conn.RemoteMultiaddr()) { |
|||
hs.refCount.Add(1) |
|||
go func() { |
|||
defer hs.refCount.Done() |
|||
|
|||
select { |
|||
// waiting for Identify here will allow us to access the peer's public and observed addresses
|
|||
// that we can dial to for a hole punch.
|
|||
case <-hs.ids.IdentifyWait(conn): |
|||
case <-hs.ctx.Done(): |
|||
} |
|||
|
|||
_ = hs.DirectConnect(conn.RemotePeer()) |
|||
}() |
|||
} |
|||
} |
|||
|
|||
func (nn *netNotifiee) Disconnected(_ network.Network, v network.Conn) {} |
|||
func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) {} |
|||
func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} |
|||
func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} |
|||
func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} |
@ -0,0 +1,373 @@ |
|||
package holepunch_test |
|||
|
|||
import ( |
|||
"context" |
|||
"net" |
|||
"sync" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/libp2p/go-libp2p" |
|||
circuit "github.com/libp2p/go-libp2p-circuit" |
|||
"github.com/libp2p/go-libp2p-core/host" |
|||
"github.com/libp2p/go-libp2p-core/network" |
|||
"github.com/libp2p/go-libp2p-core/peer" |
|||
"github.com/libp2p/go-libp2p-core/peerstore" |
|||
"github.com/libp2p/go-libp2p-testing/race" |
|||
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch" |
|||
holepunch_pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb" |
|||
"github.com/libp2p/go-libp2p/p2p/protocol/identify" |
|||
"github.com/libp2p/go-msgio/protoio" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
|
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
type mockEventTracer struct { |
|||
mutex sync.Mutex |
|||
events []*holepunch.Event |
|||
} |
|||
|
|||
func (m *mockEventTracer) Trace(evt *holepunch.Event) { |
|||
m.mutex.Lock() |
|||
m.events = append(m.events, evt) |
|||
m.mutex.Unlock() |
|||
} |
|||
|
|||
func (m *mockEventTracer) getEvents() []*holepunch.Event { |
|||
m.mutex.Lock() |
|||
defer m.mutex.Unlock() |
|||
// copy the slice
|
|||
return append([]*holepunch.Event{}, m.events...) |
|||
} |
|||
|
|||
var _ holepunch.EventTracer = &mockEventTracer{} |
|||
|
|||
func TestNoHolePunchIfDirectConnExists(t *testing.T) { |
|||
tr := &mockEventTracer{} |
|||
h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr)) |
|||
defer h1.Close() |
|||
h2, _ := mkHostWithHolePunchSvc(t) |
|||
defer h2.Close() |
|||
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{ |
|||
ID: h2.ID(), |
|||
Addrs: h2.Addrs(), |
|||
})) |
|||
time.Sleep(50 * time.Millisecond) |
|||
nc1 := len(h1.Network().ConnsToPeer(h2.ID())) |
|||
require.GreaterOrEqual(t, nc1, 1) |
|||
nc2 := len(h2.Network().ConnsToPeer(h1.ID())) |
|||
require.GreaterOrEqual(t, nc2, 1) |
|||
|
|||
require.NoError(t, hps.DirectConnect(h2.ID())) |
|||
require.Equal(t, len(h1.Network().ConnsToPeer(h2.ID())), nc1) |
|||
require.Equal(t, len(h2.Network().ConnsToPeer(h1.ID())), nc2) |
|||
require.Empty(t, tr.getEvents()) |
|||
} |
|||
|
|||
func TestDirectDialWorks(t *testing.T) { |
|||
// mark all addresses as public
|
|||
cpy := manet.Private4 |
|||
manet.Private4 = []*net.IPNet{} |
|||
defer func() { manet.Private4 = cpy }() |
|||
|
|||
tr := &mockEventTracer{} |
|||
h1, h1ps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr)) |
|||
defer h1.Close() |
|||
h2, _ := mkHostWithHolePunchSvc(t) |
|||
defer h2.Close() |
|||
h2.RemoveStreamHandler(holepunch.Protocol) |
|||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL) |
|||
|
|||
// try to hole punch without any connection and streams, if it works -> it's a direct connection
|
|||
require.Len(t, h1.Network().ConnsToPeer(h2.ID()), 0) |
|||
require.NoError(t, h1ps.DirectConnect(h2.ID())) |
|||
require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1) |
|||
require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1) |
|||
events := tr.getEvents() |
|||
require.Len(t, events, 1) |
|||
require.Equal(t, events[0].Type, holepunch.DirectDialEvtT) |
|||
} |
|||
|
|||
func TestEndToEndSimConnect(t *testing.T) { |
|||
tr := &mockEventTracer{} |
|||
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), true) |
|||
defer h1.Close() |
|||
defer h2.Close() |
|||
defer relay.Close() |
|||
|
|||
// wait till a direct connection is complete
|
|||
ensureDirectConn(t, h1, h2) |
|||
// ensure no hole-punching streams are open on either side
|
|||
ensureNoHolePunchingStream(t, h1, h2) |
|||
events := tr.getEvents() |
|||
require.Len(t, events, 3) |
|||
require.Equal(t, events[0].Type, holepunch.StartHolePunchEvtT) |
|||
require.Equal(t, events[1].Type, holepunch.HolePunchAttemptEvtT) |
|||
require.Equal(t, events[2].Type, holepunch.EndHolePunchEvtT) |
|||
} |
|||
|
|||
func TestFailuresOnInitiator(t *testing.T) { |
|||
tcs := map[string]struct { |
|||
rhandler func(s network.Stream) |
|||
errMsg string |
|||
holePunchTimeout time.Duration |
|||
}{ |
|||
"responder does NOT send a CONNECT message": { |
|||
rhandler: func(s network.Stream) { |
|||
wr := protoio.NewDelimitedWriter(s) |
|||
wr.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_SYNC.Enum()}) |
|||
}, |
|||
errMsg: "expect CONNECT message, got SYNC", |
|||
}, |
|||
"responder does NOT support protocol": { |
|||
rhandler: nil, |
|||
}, |
|||
"unable to READ CONNECT message from responder": { |
|||
rhandler: func(s network.Stream) { |
|||
s.Reset() |
|||
}, |
|||
errMsg: "failed to read CONNECT message", |
|||
}, |
|||
"responder does NOT reply within hole punch deadline": { |
|||
holePunchTimeout: 10 * time.Millisecond, |
|||
rhandler: func(s network.Stream) { |
|||
time.Sleep(5 * time.Second) |
|||
}, |
|||
errMsg: "i/o deadline reached", |
|||
}, |
|||
} |
|||
|
|||
for name, tc := range tcs { |
|||
t.Run(name, func(t *testing.T) { |
|||
if tc.holePunchTimeout != 0 { |
|||
cpy := holepunch.StreamTimeout |
|||
holepunch.StreamTimeout = tc.holePunchTimeout |
|||
defer func() { holepunch.StreamTimeout = cpy }() |
|||
} |
|||
|
|||
tr := &mockEventTracer{} |
|||
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false) |
|||
defer h1.Close() |
|||
defer h2.Close() |
|||
defer relay.Close() |
|||
hps := addHolePunchService(t, h2) |
|||
|
|||
if tc.rhandler != nil { |
|||
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler) |
|||
} else { |
|||
h1.RemoveStreamHandler(holepunch.Protocol) |
|||
} |
|||
|
|||
err := hps.DirectConnect(h1.ID()) |
|||
require.Error(t, err) |
|||
if tc.errMsg != "" { |
|||
require.Contains(t, err.Error(), tc.errMsg) |
|||
} |
|||
}) |
|||
|
|||
} |
|||
} |
|||
|
|||
func TestFailuresOnResponder(t *testing.T) { |
|||
tcs := map[string]struct { |
|||
initiator func(s network.Stream) |
|||
errMsg string |
|||
holePunchTimeout time.Duration |
|||
}{ |
|||
"initiator does NOT send a CONNECT message": { |
|||
initiator: func(s network.Stream) { |
|||
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_SYNC.Enum()}) |
|||
}, |
|||
errMsg: "expected CONNECT message", |
|||
}, |
|||
"initiator does NOT send a SYNC message after a Connect message": { |
|||
initiator: func(s network.Stream) { |
|||
w := protoio.NewDelimitedWriter(s) |
|||
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) |
|||
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) |
|||
}, |
|||
errMsg: "expected SYNC message", |
|||
}, |
|||
"initiator does NOT reply within hole punch deadline": { |
|||
holePunchTimeout: 10 * time.Millisecond, |
|||
initiator: func(s network.Stream) { |
|||
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) |
|||
time.Sleep(10 * time.Second) |
|||
}, |
|||
errMsg: "i/o deadline reached", |
|||
}, |
|||
} |
|||
|
|||
for name, tc := range tcs { |
|||
t.Run(name, func(t *testing.T) { |
|||
if tc.holePunchTimeout != 0 { |
|||
cpy := holepunch.StreamTimeout |
|||
holepunch.StreamTimeout = tc.holePunchTimeout |
|||
defer func() { holepunch.StreamTimeout = cpy }() |
|||
} |
|||
|
|||
tr := &mockEventTracer{} |
|||
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false) |
|||
defer h1.Close() |
|||
defer h2.Close() |
|||
defer relay.Close() |
|||
|
|||
s, err := h2.NewStream(context.Background(), h1.ID(), holepunch.Protocol) |
|||
require.NoError(t, err) |
|||
|
|||
go tc.initiator(s) |
|||
|
|||
getTracerError := func(tr *mockEventTracer) []string { |
|||
var errs []string |
|||
events := tr.getEvents() |
|||
for _, ev := range events { |
|||
if errEv, ok := ev.Evt.(*holepunch.ProtocolErrorEvt); ok { |
|||
errs = append(errs, errEv.Error) |
|||
} |
|||
} |
|||
return errs |
|||
} |
|||
|
|||
require.Eventually(t, func() bool { return len(getTracerError(tr)) > 0 }, 5*time.Second, 100*time.Millisecond) |
|||
errs := getTracerError(tr) |
|||
require.Len(t, errs, 1) |
|||
require.Contains(t, errs[0], tc.errMsg) |
|||
}) |
|||
|
|||
} |
|||
} |
|||
|
|||
func ensureNoHolePunchingStream(t *testing.T, h1, h2 host.Host) { |
|||
require.Eventually(t, func() bool { |
|||
for _, c := range h1.Network().ConnsToPeer(h2.ID()) { |
|||
for _, s := range c.GetStreams() { |
|||
if s.ID() == string(holepunch.Protocol) { |
|||
return false |
|||
} |
|||
} |
|||
} |
|||
return true |
|||
}, 5*time.Second, 50*time.Millisecond) |
|||
|
|||
require.Eventually(t, func() bool { |
|||
for _, c := range h2.Network().ConnsToPeer(h1.ID()) { |
|||
for _, s := range c.GetStreams() { |
|||
if s.ID() == string(holepunch.Protocol) { |
|||
return false |
|||
} |
|||
} |
|||
} |
|||
return true |
|||
}, 5*time.Second, 50*time.Millisecond) |
|||
} |
|||
|
|||
func ensureDirectConn(t *testing.T, h1, h2 host.Host) { |
|||
require.Eventually(t, func() bool { |
|||
for _, c := range h1.Network().ConnsToPeer(h2.ID()) { |
|||
if _, err := c.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT); err != nil { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
}, 5*time.Second, 50*time.Millisecond) |
|||
|
|||
require.Eventually(t, func() bool { |
|||
for _, c := range h2.Network().ConnsToPeer(h1.ID()) { |
|||
if _, err := c.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT); err != nil { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
}, 5*time.Second, 50*time.Millisecond) |
|||
} |
|||
|
|||
func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Host) host.Host { |
|||
if race.WithRace() { |
|||
t.Skip("modifying manet.Private4 is racy") |
|||
} |
|||
pi := peer.AddrInfo{ |
|||
ID: relay.ID(), |
|||
Addrs: relay.Addrs(), |
|||
} |
|||
|
|||
cpy := manet.Private4 |
|||
manet.Private4 = []*net.IPNet{} |
|||
defer func() { manet.Private4 = cpy }() |
|||
|
|||
h, err := libp2p.New(ctx, |
|||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")), |
|||
libp2p.EnableRelay(), |
|||
libp2p.EnableAutoRelay(), |
|||
libp2p.ForceReachabilityPrivate(), |
|||
libp2p.StaticRelays([]peer.AddrInfo{pi}), |
|||
) |
|||
require.NoError(t, err) |
|||
|
|||
// wait till we have a relay addr
|
|||
require.Eventually(t, func() bool { |
|||
for _, a := range h.Addrs() { |
|||
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
}, 5*time.Second, 50*time.Millisecond) |
|||
return h |
|||
} |
|||
|
|||
func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) { |
|||
t.Helper() |
|||
h1, _ = mkHostWithHolePunchSvc(t, h1Opt) |
|||
var err error |
|||
relay, err = libp2p.New(context.Background(), |
|||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")), |
|||
libp2p.EnableRelay(circuit.OptHop), |
|||
) |
|||
require.NoError(t, err) |
|||
h2 = mkHostWithStaticAutoRelay(t, context.Background(), relay) |
|||
if addHolePuncher { |
|||
hps = addHolePunchService(t, h2) |
|||
} |
|||
|
|||
// h1 has a relay addr
|
|||
// h2 should connect to the relay addr
|
|||
var raddr ma.Multiaddr |
|||
for _, a := range h2.Addrs() { |
|||
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil { |
|||
raddr = a |
|||
break |
|||
} |
|||
} |
|||
require.NotEmpty(t, raddr) |
|||
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{ |
|||
ID: h2.ID(), |
|||
Addrs: []ma.Multiaddr{raddr}, |
|||
})) |
|||
return |
|||
} |
|||
|
|||
func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service { |
|||
t.Helper() |
|||
ids, err := identify.NewIDService(h) |
|||
require.NoError(t, err) |
|||
hps, err := holepunch.NewService(h, ids) |
|||
require.NoError(t, err) |
|||
return hps |
|||
} |
|||
|
|||
func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, *holepunch.Service) { |
|||
t.Helper() |
|||
h, err := libp2p.New( |
|||
context.Background(), |
|||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")), |
|||
) |
|||
require.NoError(t, err) |
|||
ids, err := identify.NewIDService(h) |
|||
require.NoError(t, err) |
|||
hps, err := holepunch.NewService(h, ids, opts...) |
|||
require.NoError(t, err) |
|||
return h, hps |
|||
} |
@ -0,0 +1,11 @@ |
|||
PB = $(wildcard *.proto) |
|||
GO = $(PB:.proto=.pb.go) |
|||
|
|||
all: $(GO) |
|||
|
|||
%.pb.go: %.proto |
|||
protoc --proto_path=$(GOPATH)/src:. --gogofast_out=. $< |
|||
|
|||
clean: |
|||
rm -f *.pb.go |
|||
rm -f *.go |
@ -0,0 +1,414 @@ |
|||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
|||
// source: holepunch.proto
|
|||
|
|||
package holepunch_pb |
|||
|
|||
import ( |
|||
fmt "fmt" |
|||
proto "github.com/gogo/protobuf/proto" |
|||
io "io" |
|||
math "math" |
|||
math_bits "math/bits" |
|||
) |
|||
|
|||
// Reference imports to suppress errors if they are not otherwise used.
|
|||
var _ = proto.Marshal |
|||
var _ = fmt.Errorf |
|||
var _ = math.Inf |
|||
|
|||
// This is a compile-time assertion to ensure that this generated file
|
|||
// is compatible with the proto package it is being compiled against.
|
|||
// A compilation error at this line likely means your copy of the
|
|||
// proto package needs to be updated.
|
|||
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
|
|||
|
|||
type HolePunch_Type int32 |
|||
|
|||
const ( |
|||
HolePunch_CONNECT HolePunch_Type = 100 |
|||
HolePunch_SYNC HolePunch_Type = 300 |
|||
) |
|||
|
|||
var HolePunch_Type_name = map[int32]string{ |
|||
100: "CONNECT", |
|||
300: "SYNC", |
|||
} |
|||
|
|||
var HolePunch_Type_value = map[string]int32{ |
|||
"CONNECT": 100, |
|||
"SYNC": 300, |
|||
} |
|||
|
|||
func (x HolePunch_Type) Enum() *HolePunch_Type { |
|||
p := new(HolePunch_Type) |
|||
*p = x |
|||
return p |
|||
} |
|||
|
|||
func (x HolePunch_Type) String() string { |
|||
return proto.EnumName(HolePunch_Type_name, int32(x)) |
|||
} |
|||
|
|||
func (x *HolePunch_Type) UnmarshalJSON(data []byte) error { |
|||
value, err := proto.UnmarshalJSONEnum(HolePunch_Type_value, data, "HolePunch_Type") |
|||
if err != nil { |
|||
return err |
|||
} |
|||
*x = HolePunch_Type(value) |
|||
return nil |
|||
} |
|||
|
|||
func (HolePunch_Type) EnumDescriptor() ([]byte, []int) { |
|||
return fileDescriptor_290ddea0f23ef64a, []int{0, 0} |
|||
} |
|||
|
|||
type HolePunch struct { |
|||
Type *HolePunch_Type `protobuf:"varint,1,opt,name=type,enum=holepunch.pb.HolePunch_Type" json:"type,omitempty"` |
|||
// For hole punching, we'll send some additional observed addresses to the remote peer
|
|||
// that could have been filtered by the Host address factory (for example: AutoRelay removes all public addresses if peer has private reachability).
|
|||
// This is a hack!
|
|||
// We plan to have a better address discovery and advertisement mechanism in the future.
|
|||
// See https://github.com/libp2p/go-libp2p-autonat/pull/98
|
|||
ObsAddrs [][]byte `protobuf:"bytes,2,rep,name=ObsAddrs" json:"ObsAddrs,omitempty"` |
|||
XXX_NoUnkeyedLiteral struct{} `json:"-"` |
|||
XXX_unrecognized []byte `json:"-"` |
|||
XXX_sizecache int32 `json:"-"` |
|||
} |
|||
|
|||
func (m *HolePunch) Reset() { *m = HolePunch{} } |
|||
func (m *HolePunch) String() string { return proto.CompactTextString(m) } |
|||
func (*HolePunch) ProtoMessage() {} |
|||
func (*HolePunch) Descriptor() ([]byte, []int) { |
|||
return fileDescriptor_290ddea0f23ef64a, []int{0} |
|||
} |
|||
func (m *HolePunch) XXX_Unmarshal(b []byte) error { |
|||
return m.Unmarshal(b) |
|||
} |
|||
func (m *HolePunch) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { |
|||
if deterministic { |
|||
return xxx_messageInfo_HolePunch.Marshal(b, m, deterministic) |
|||
} else { |
|||
b = b[:cap(b)] |
|||
n, err := m.MarshalToSizedBuffer(b) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return b[:n], nil |
|||
} |
|||
} |
|||
func (m *HolePunch) XXX_Merge(src proto.Message) { |
|||
xxx_messageInfo_HolePunch.Merge(m, src) |
|||
} |
|||
func (m *HolePunch) XXX_Size() int { |
|||
return m.Size() |
|||
} |
|||
func (m *HolePunch) XXX_DiscardUnknown() { |
|||
xxx_messageInfo_HolePunch.DiscardUnknown(m) |
|||
} |
|||
|
|||
var xxx_messageInfo_HolePunch proto.InternalMessageInfo |
|||
|
|||
func (m *HolePunch) GetType() HolePunch_Type { |
|||
if m != nil && m.Type != nil { |
|||
return *m.Type |
|||
} |
|||
return HolePunch_CONNECT |
|||
} |
|||
|
|||
func (m *HolePunch) GetObsAddrs() [][]byte { |
|||
if m != nil { |
|||
return m.ObsAddrs |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func init() { |
|||
proto.RegisterEnum("holepunch.pb.HolePunch_Type", HolePunch_Type_name, HolePunch_Type_value) |
|||
proto.RegisterType((*HolePunch)(nil), "holepunch.pb.HolePunch") |
|||
} |
|||
|
|||
func init() { proto.RegisterFile("holepunch.proto", fileDescriptor_290ddea0f23ef64a) } |
|||
|
|||
var fileDescriptor_290ddea0f23ef64a = []byte{ |
|||
// 153 bytes of a gzipped FileDescriptorProto
|
|||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0xc8, 0xcf, 0x49, |
|||
0x2d, 0x28, 0xcd, 0x4b, 0xce, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x41, 0x12, 0x48, |
|||
0x52, 0xaa, 0xe4, 0xe2, 0xf4, 0xc8, 0xcf, 0x49, 0x0d, 0x00, 0xf1, 0x85, 0x0c, 0xb8, 0x58, 0x4a, |
|||
0x2a, 0x0b, 0x52, 0x25, 0x18, 0x15, 0x18, 0x35, 0xf8, 0x8c, 0x64, 0xf4, 0x90, 0x55, 0xea, 0xc1, |
|||
0x95, 0xe9, 0x85, 0x54, 0x16, 0xa4, 0x06, 0x81, 0x55, 0x0a, 0x49, 0x71, 0x71, 0xf8, 0x27, 0x15, |
|||
0x3b, 0xa6, 0xa4, 0x14, 0x15, 0x4b, 0x30, 0x29, 0x30, 0x6b, 0xf0, 0x04, 0xc1, 0xf9, 0x4a, 0x72, |
|||
0x5c, 0x2c, 0x20, 0x95, 0x42, 0xdc, 0x5c, 0xec, 0xce, 0xfe, 0x7e, 0x7e, 0xae, 0xce, 0x21, 0x02, |
|||
0x29, 0x42, 0x9c, 0x5c, 0x2c, 0xc1, 0x91, 0x7e, 0xce, 0x02, 0x6b, 0x98, 0x9c, 0x78, 0x4e, 0x3c, |
|||
0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0x46, 0x40, 0x00, 0x00, 0x00, 0xff, |
|||
0xff, 0x62, 0xf4, 0xc8, 0x7c, 0xa8, 0x00, 0x00, 0x00, |
|||
} |
|||
|
|||
func (m *HolePunch) Marshal() (dAtA []byte, err error) { |
|||
size := m.Size() |
|||
dAtA = make([]byte, size) |
|||
n, err := m.MarshalToSizedBuffer(dAtA[:size]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return dAtA[:n], nil |
|||
} |
|||
|
|||
func (m *HolePunch) MarshalTo(dAtA []byte) (int, error) { |
|||
size := m.Size() |
|||
return m.MarshalToSizedBuffer(dAtA[:size]) |
|||
} |
|||
|
|||
func (m *HolePunch) MarshalToSizedBuffer(dAtA []byte) (int, error) { |
|||
i := len(dAtA) |
|||
_ = i |
|||
var l int |
|||
_ = l |
|||
if m.XXX_unrecognized != nil { |
|||
i -= len(m.XXX_unrecognized) |
|||
copy(dAtA[i:], m.XXX_unrecognized) |
|||
} |
|||
if len(m.ObsAddrs) > 0 { |
|||
for iNdEx := len(m.ObsAddrs) - 1; iNdEx >= 0; iNdEx-- { |
|||
i -= len(m.ObsAddrs[iNdEx]) |
|||
copy(dAtA[i:], m.ObsAddrs[iNdEx]) |
|||
i = encodeVarintHolepunch(dAtA, i, uint64(len(m.ObsAddrs[iNdEx]))) |
|||
i-- |
|||
dAtA[i] = 0x12 |
|||
} |
|||
} |
|||
if m.Type != nil { |
|||
i = encodeVarintHolepunch(dAtA, i, uint64(*m.Type)) |
|||
i-- |
|||
dAtA[i] = 0x8 |
|||
} |
|||
return len(dAtA) - i, nil |
|||
} |
|||
|
|||
func encodeVarintHolepunch(dAtA []byte, offset int, v uint64) int { |
|||
offset -= sovHolepunch(v) |
|||
base := offset |
|||
for v >= 1<<7 { |
|||
dAtA[offset] = uint8(v&0x7f | 0x80) |
|||
v >>= 7 |
|||
offset++ |
|||
} |
|||
dAtA[offset] = uint8(v) |
|||
return base |
|||
} |
|||
func (m *HolePunch) Size() (n int) { |
|||
if m == nil { |
|||
return 0 |
|||
} |
|||
var l int |
|||
_ = l |
|||
if m.Type != nil { |
|||
n += 1 + sovHolepunch(uint64(*m.Type)) |
|||
} |
|||
if len(m.ObsAddrs) > 0 { |
|||
for _, b := range m.ObsAddrs { |
|||
l = len(b) |
|||
n += 1 + l + sovHolepunch(uint64(l)) |
|||
} |
|||
} |
|||
if m.XXX_unrecognized != nil { |
|||
n += len(m.XXX_unrecognized) |
|||
} |
|||
return n |
|||
} |
|||
|
|||
func sovHolepunch(x uint64) (n int) { |
|||
return (math_bits.Len64(x|1) + 6) / 7 |
|||
} |
|||
func sozHolepunch(x uint64) (n int) { |
|||
return sovHolepunch(uint64((x << 1) ^ uint64((int64(x) >> 63)))) |
|||
} |
|||
func (m *HolePunch) Unmarshal(dAtA []byte) error { |
|||
l := len(dAtA) |
|||
iNdEx := 0 |
|||
for iNdEx < l { |
|||
preIndex := iNdEx |
|||
var wire uint64 |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
wire |= uint64(b&0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
fieldNum := int32(wire >> 3) |
|||
wireType := int(wire & 0x7) |
|||
if wireType == 4 { |
|||
return fmt.Errorf("proto: HolePunch: wiretype end group for non-group") |
|||
} |
|||
if fieldNum <= 0 { |
|||
return fmt.Errorf("proto: HolePunch: illegal tag %d (wire type %d)", fieldNum, wire) |
|||
} |
|||
switch fieldNum { |
|||
case 1: |
|||
if wireType != 0 { |
|||
return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) |
|||
} |
|||
var v HolePunch_Type |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
v |= HolePunch_Type(b&0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
m.Type = &v |
|||
case 2: |
|||
if wireType != 2 { |
|||
return fmt.Errorf("proto: wrong wireType = %d for field ObsAddrs", wireType) |
|||
} |
|||
var byteLen int |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
byteLen |= int(b&0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
if byteLen < 0 { |
|||
return ErrInvalidLengthHolepunch |
|||
} |
|||
postIndex := iNdEx + byteLen |
|||
if postIndex < 0 { |
|||
return ErrInvalidLengthHolepunch |
|||
} |
|||
if postIndex > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
m.ObsAddrs = append(m.ObsAddrs, make([]byte, postIndex-iNdEx)) |
|||
copy(m.ObsAddrs[len(m.ObsAddrs)-1], dAtA[iNdEx:postIndex]) |
|||
iNdEx = postIndex |
|||
default: |
|||
iNdEx = preIndex |
|||
skippy, err := skipHolepunch(dAtA[iNdEx:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if skippy < 0 { |
|||
return ErrInvalidLengthHolepunch |
|||
} |
|||
if (iNdEx + skippy) < 0 { |
|||
return ErrInvalidLengthHolepunch |
|||
} |
|||
if (iNdEx + skippy) > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) |
|||
iNdEx += skippy |
|||
} |
|||
} |
|||
|
|||
if iNdEx > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
return nil |
|||
} |
|||
func skipHolepunch(dAtA []byte) (n int, err error) { |
|||
l := len(dAtA) |
|||
iNdEx := 0 |
|||
depth := 0 |
|||
for iNdEx < l { |
|||
var wire uint64 |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
wire |= (uint64(b) & 0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
wireType := int(wire & 0x7) |
|||
switch wireType { |
|||
case 0: |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
iNdEx++ |
|||
if dAtA[iNdEx-1] < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
case 1: |
|||
iNdEx += 8 |
|||
case 2: |
|||
var length int |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowHolepunch |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
length |= (int(b) & 0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
if length < 0 { |
|||
return 0, ErrInvalidLengthHolepunch |
|||
} |
|||
iNdEx += length |
|||
case 3: |
|||
depth++ |
|||
case 4: |
|||
if depth == 0 { |
|||
return 0, ErrUnexpectedEndOfGroupHolepunch |
|||
} |
|||
depth-- |
|||
case 5: |
|||
iNdEx += 4 |
|||
default: |
|||
return 0, fmt.Errorf("proto: illegal wireType %d", wireType) |
|||
} |
|||
if iNdEx < 0 { |
|||
return 0, ErrInvalidLengthHolepunch |
|||
} |
|||
if depth == 0 { |
|||
return iNdEx, nil |
|||
} |
|||
} |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
|
|||
var ( |
|||
ErrInvalidLengthHolepunch = fmt.Errorf("proto: negative length found during unmarshaling") |
|||
ErrIntOverflowHolepunch = fmt.Errorf("proto: integer overflow") |
|||
ErrUnexpectedEndOfGroupHolepunch = fmt.Errorf("proto: unexpected end of group") |
|||
) |
@ -0,0 +1,14 @@ |
|||
syntax = "proto2"; |
|||
|
|||
package holepunch.pb; |
|||
|
|||
// spec: https://github.com/libp2p/specs/blob/master/relay/DCUtR.md |
|||
message HolePunch { |
|||
enum Type { |
|||
CONNECT = 100; |
|||
SYNC = 300; |
|||
} |
|||
|
|||
optional Type type=1; |
|||
repeated bytes ObsAddrs = 2; |
|||
} |
@ -0,0 +1,252 @@ |
|||
package holepunch |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/libp2p/go-libp2p-core/peer" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
) |
|||
|
|||
const ( |
|||
tracerGCInterval = 2 * time.Minute |
|||
tracerCacheDuration = 5 * time.Minute |
|||
) |
|||
|
|||
// WithTracer is a Service option that enables hole punching tracing
|
|||
func WithTracer(tr EventTracer) Option { |
|||
return func(hps *Service) error { |
|||
t := &tracer{ |
|||
tr: tr, |
|||
self: hps.host.ID(), |
|||
peers: make(map[peer.ID]struct { |
|||
counter int |
|||
last time.Time |
|||
}), |
|||
} |
|||
t.refCount.Add(1) |
|||
t.ctx, t.ctxCancel = context.WithCancel(context.Background()) |
|||
go t.gc() |
|||
hps.tracer = t |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
type tracer struct { |
|||
tr EventTracer |
|||
self peer.ID |
|||
|
|||
refCount sync.WaitGroup |
|||
ctx context.Context |
|||
ctxCancel context.CancelFunc |
|||
|
|||
mutex sync.Mutex |
|||
peers map[peer.ID]struct { |
|||
counter int |
|||
last time.Time |
|||
} |
|||
} |
|||
|
|||
type EventTracer interface { |
|||
Trace(evt *Event) |
|||
} |
|||
|
|||
type Event struct { |
|||
Timestamp int64 // UNIX nanos
|
|||
Peer peer.ID // local peer ID
|
|||
Remote peer.ID // remote peer ID
|
|||
Type string // event type
|
|||
Evt interface{} // the actual event
|
|||
} |
|||
|
|||
// Event Types
|
|||
const ( |
|||
DirectDialEvtT = "DirectDial" |
|||
ProtocolErrorEvtT = "ProtocolError" |
|||
StartHolePunchEvtT = "StartHolePunch" |
|||
EndHolePunchEvtT = "EndHolePunch" |
|||
HolePunchAttemptEvtT = "HolePunchAttempt" |
|||
) |
|||
|
|||
// Event Objects
|
|||
type DirectDialEvt struct { |
|||
Success bool |
|||
EllapsedTime time.Duration |
|||
Error string `json:",omitempty"` |
|||
} |
|||
|
|||
type ProtocolErrorEvt struct { |
|||
Error string |
|||
} |
|||
|
|||
type StartHolePunchEvt struct { |
|||
RemoteAddrs []string |
|||
RTT time.Duration |
|||
} |
|||
|
|||
type EndHolePunchEvt struct { |
|||
Success bool |
|||
EllapsedTime time.Duration |
|||
Error string `json:",omitempty"` |
|||
} |
|||
|
|||
type HolePunchAttemptEvt struct { |
|||
Attempt int |
|||
} |
|||
|
|||
// tracer interface
|
|||
func (t *tracer) DirectDialSuccessful(p peer.ID, dt time.Duration) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: time.Now().UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: DirectDialEvtT, |
|||
Evt: &DirectDialEvt{ |
|||
Success: true, |
|||
EllapsedTime: dt, |
|||
}, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) DirectDialFailed(p peer.ID, dt time.Duration, err error) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: time.Now().UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: DirectDialEvtT, |
|||
Evt: &DirectDialEvt{ |
|||
Success: false, |
|||
EllapsedTime: dt, |
|||
Error: err.Error(), |
|||
}, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) ProtocolError(p peer.ID, err error) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: time.Now().UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: ProtocolErrorEvtT, |
|||
Evt: &ProtocolErrorEvt{ |
|||
Error: err.Error(), |
|||
}, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) StartHolePunch(p peer.ID, obsAddrs []ma.Multiaddr, rtt time.Duration) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
addrs := make([]string, 0, len(obsAddrs)) |
|||
for _, a := range obsAddrs { |
|||
addrs = append(addrs, a.String()) |
|||
} |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: time.Now().UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: StartHolePunchEvtT, |
|||
Evt: &StartHolePunchEvt{ |
|||
RemoteAddrs: addrs, |
|||
RTT: rtt, |
|||
}, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) EndHolePunch(p peer.ID, dt time.Duration, err error) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
evt := &EndHolePunchEvt{ |
|||
Success: err == nil, |
|||
EllapsedTime: dt, |
|||
} |
|||
if err != nil { |
|||
evt.Error = err.Error() |
|||
} |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: time.Now().UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: EndHolePunchEvtT, |
|||
Evt: evt, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) HolePunchAttempt(p peer.ID) { |
|||
if t == nil { |
|||
return |
|||
} |
|||
|
|||
now := time.Now() |
|||
t.mutex.Lock() |
|||
attempt := t.peers[p] |
|||
attempt.counter++ |
|||
counter := attempt.counter |
|||
attempt.last = now |
|||
t.peers[p] = attempt |
|||
t.mutex.Unlock() |
|||
|
|||
t.tr.Trace(&Event{ |
|||
Timestamp: now.UnixNano(), |
|||
Peer: t.self, |
|||
Remote: p, |
|||
Type: HolePunchAttemptEvtT, |
|||
Evt: &HolePunchAttemptEvt{Attempt: counter}, |
|||
}) |
|||
} |
|||
|
|||
func (t *tracer) gc() { |
|||
defer func() { |
|||
fmt.Println("done") |
|||
t.refCount.Done() |
|||
}() |
|||
|
|||
timer := time.NewTicker(tracerGCInterval) |
|||
defer timer.Stop() |
|||
|
|||
for { |
|||
select { |
|||
case now := <-timer.C: |
|||
t.mutex.Lock() |
|||
for id, entry := range t.peers { |
|||
if entry.last.Before(now.Add(-tracerCacheDuration)) { |
|||
delete(t.peers, id) |
|||
} |
|||
} |
|||
t.mutex.Unlock() |
|||
case <-t.ctx.Done(): |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (t *tracer) Close() error { |
|||
if t == nil { |
|||
return nil |
|||
} |
|||
|
|||
t.ctxCancel() |
|||
t.refCount.Wait() |
|||
return nil |
|||
} |
Loading…
Reference in new issue