Browse Source

Merge pull request #2248 from libp2p/simplify-nat

move NAT mapping logic out of the host, add tests for NAT handling
pull/2266/head
Marten Seemann 2 years ago
committed by GitHub
parent
commit
410248e111
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 89
      p2p/host/basic/basic_host.go
  2. 92
      p2p/host/basic/mock_nat_test.go
  3. 6
      p2p/host/basic/mocks.go
  4. 195
      p2p/host/basic/natmgr.go
  5. 108
      p2p/host/basic/natmgr_test.go
  6. 119
      p2p/net/nat/mapping.go
  7. 124
      p2p/net/nat/mock_nat_test.go
  8. 210
      p2p/net/nat/nat.go
  9. 69
      p2p/net/nat/nat_test.go

89
p2p/host/basic/basic_host.go

@ -25,7 +25,6 @@ import (
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/host/pstoremanager"
"github.com/libp2p/go-libp2p/p2p/host/relaysvc"
inat "github.com/libp2p/go-libp2p/p2p/net/nat"
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
@ -858,99 +857,19 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
finalAddrs = dedupAddrs(finalAddrs)
var natMappings []inat.Mapping
// natmgr is nil if we do not use nat option;
// h.natmgr.NAT() is nil if not ready, or no nat is available.
if h.natmgr != nil && h.natmgr.NAT() != nil {
natMappings = h.natmgr.NAT().Mappings()
}
if len(natMappings) > 0 {
if h.natmgr != nil {
// We have successfully mapped ports on our NAT. Use those
// instead of observed addresses (mostly).
// First, generate a mapping table.
// protocol -> internal port -> external addr
ports := make(map[string]map[int]net.Addr)
for _, m := range natMappings {
addr, err := m.ExternalAddr()
if err != nil {
// mapping not ready yet.
continue
}
protoPorts, ok := ports[m.Protocol()]
if !ok {
protoPorts = make(map[int]net.Addr)
ports[m.Protocol()] = protoPorts
}
protoPorts[m.InternalPort()] = addr
}
// Next, apply this mapping to our addresses.
for _, listen := range listenAddrs {
found := false
transport, rest := ma.SplitFunc(listen, func(c ma.Component) bool {
if found {
return true
}
switch c.Protocol().Code {
case ma.P_TCP, ma.P_UDP:
found = true
}
return false
})
if !manet.IsThinWaist(transport) {
continue
}
naddr, err := manet.ToNetAddr(transport)
if err != nil {
log.Error("error parsing net multiaddr %q: %s", transport, err)
extMaddr := h.natmgr.GetMapping(listen)
if extMaddr == nil {
// not mapped
continue
}
var (
ip net.IP
iport int
protocol string
)
switch naddr := naddr.(type) {
case *net.TCPAddr:
ip = naddr.IP
iport = naddr.Port
protocol = "tcp"
case *net.UDPAddr:
ip = naddr.IP
iport = naddr.Port
protocol = "udp"
default:
continue
}
if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
// We only map global unicast & unspecified addresses ports.
// Not broadcast, multicast, etc.
continue
}
mappedAddr, ok := ports[protocol][iport]
if !ok {
// Not mapped.
continue
}
mappedMaddr, err := manet.FromNetAddr(mappedAddr)
if err != nil {
log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err)
continue
}
extMaddr := mappedMaddr
if rest != nil {
extMaddr = ma.Join(extMaddr, rest)
}
// if the router reported a sane address
if !manet.IsIPUnspecified(extMaddr) {
// Add in the mapped addr.

92
p2p/host/basic/mock_nat_test.go

@ -0,0 +1,92 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/libp2p/go-libp2p/p2p/host/basic (interfaces: NAT)
// Package basichost is a generated GoMock package.
package basichost
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockNAT is a mock of NAT interface.
type MockNAT struct {
ctrl *gomock.Controller
recorder *MockNATMockRecorder
}
// MockNATMockRecorder is the mock recorder for MockNAT.
type MockNATMockRecorder struct {
mock *MockNAT
}
// NewMockNAT creates a new mock instance.
func NewMockNAT(ctrl *gomock.Controller) *MockNAT {
mock := &MockNAT{ctrl: ctrl}
mock.recorder = &MockNATMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockNAT) EXPECT() *MockNATMockRecorder {
return m.recorder
}
// AddMapping mocks base method.
func (m *MockNAT) AddMapping(arg0 string, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddMapping", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddMapping indicates an expected call of AddMapping.
func (mr *MockNATMockRecorder) AddMapping(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNAT)(nil).AddMapping), arg0, arg1)
}
// Close mocks base method.
func (m *MockNAT) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockNATMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNAT)(nil).Close))
}
// GetMapping mocks base method.
func (m *MockNAT) GetMapping(arg0 string, arg1 int) (netip.AddrPort, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMapping", arg0, arg1)
ret0, _ := ret[0].(netip.AddrPort)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetMapping indicates an expected call of GetMapping.
func (mr *MockNATMockRecorder) GetMapping(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNAT)(nil).GetMapping), arg0, arg1)
}
// RemoveMapping mocks base method.
func (m *MockNAT) RemoveMapping(arg0 string, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveMapping", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveMapping indicates an expected call of RemoveMapping.
func (mr *MockNATMockRecorder) RemoveMapping(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNAT)(nil).RemoveMapping), arg0, arg1)
}

6
p2p/host/basic/mocks.go

@ -0,0 +1,6 @@
//go:build gomock || generate
package basichost
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT"
type NAT nat

195
p2p/host/basic/natmgr.go

@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"net/netip"
"strconv"
"sync"
"time"
@ -12,24 +13,37 @@ import (
inat "github.com/libp2p/go-libp2p/p2p/net/nat"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// NATManager is a simple interface to manage NAT devices.
// It listens Listen and ListenClose notifications from the network.Network,
// and tries to obtain port mappings for those.
type NATManager interface {
// NAT gets the NAT device managed by the NAT manager.
NAT() *inat.NAT
// Ready receives a notification when the NAT device is ready for use.
Ready() <-chan struct{}
GetMapping(ma.Multiaddr) ma.Multiaddr
io.Closer
}
// NewNATManager creates a NAT manager.
func NewNATManager(net network.Network) NATManager {
return newNatManager(net)
return newNATManager(net)
}
type entry struct {
protocol string
port int
}
type nat interface {
AddMapping(protocol string, port int) error
RemoveMapping(protocol string, port int) error
GetMapping(protocol string, port int) (netip.AddrPort, bool)
io.Closer
}
// so we can mock it in tests
var discoverNAT = func(ctx context.Context) (nat, error) { return inat.DiscoverNAT(ctx) }
// natManager takes care of adding + removing port mappings to the nat.
// Initialized with the host if it has a NATPortMap option enabled.
// natManager receives signals from the network, and check on nat mappings:
@ -39,22 +53,23 @@ func NewNATManager(net network.Network) NATManager {
type natManager struct {
net network.Network
natMx sync.RWMutex
nat *inat.NAT
nat nat
ready chan struct{} // closed once the nat is ready to process port mappings
syncFlag chan struct{}
syncFlag chan struct{} // cap: 1
tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function
refCount sync.WaitGroup
ctxCancel context.CancelFunc
}
func newNatManager(net network.Network) *natManager {
func newNATManager(net network.Network) *natManager {
ctx, cancel := context.WithCancel(context.Background())
nmgr := &natManager{
net: net,
ready: make(chan struct{}),
syncFlag: make(chan struct{}, 1),
ctxCancel: cancel,
tracked: make(map[entry]bool),
}
nmgr.refCount.Add(1)
go nmgr.background(ctx)
@ -69,36 +84,29 @@ func (nmgr *natManager) Close() error {
return nil
}
// Ready returns a channel which will be closed when the NAT has been found
// and is ready to be used, or the search process is done.
func (nmgr *natManager) Ready() <-chan struct{} {
return nmgr.ready
}
func (nmgr *natManager) background(ctx context.Context) {
defer nmgr.refCount.Done()
defer func() {
nmgr.natMx.Lock()
defer nmgr.natMx.Unlock()
if nmgr.nat != nil {
nmgr.nat.Close()
}
nmgr.natMx.Unlock()
}()
discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
natInstance, err := inat.DiscoverNAT(discoverCtx)
natInstance, err := discoverNAT(discoverCtx)
if err != nil {
log.Info("DiscoverNAT error:", err)
close(nmgr.ready)
return
}
nmgr.natMx.Lock()
nmgr.nat = natInstance
nmgr.natMx.Unlock()
close(nmgr.ready)
// sign natManager up for network notifications
// we need to sign up here to avoid missing some notifs
@ -127,10 +135,10 @@ func (nmgr *natManager) sync() {
// doSync syncs the current NAT mappings, removing any outdated mappings and adding any
// new mappings.
func (nmgr *natManager) doSync() {
ports := map[string]map[int]bool{
"tcp": {},
"udp": {},
for e := range nmgr.tracked {
nmgr.tracked[e] = false
}
var newAddresses []entry
for _, maddr := range nmgr.net.ListenAddresses() {
// Strip the IP
maIP, rest := ma.SplitFirst(maddr)
@ -144,10 +152,9 @@ func (nmgr *natManager) doSync() {
continue
}
// Only bother if we're listening on a
// unicast/unspecified IP.
// Only bother if we're listening on an unicast / unspecified IP.
ip := net.IP(maIP.RawValue())
if !(ip.IsGlobalUnicast() || ip.IsUnspecified()) {
if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
continue
}
@ -166,74 +173,118 @@ func (nmgr *natManager) doSync() {
default:
continue
}
port, err := strconv.ParseUint(proto.Value(), 10, 16)
if err != nil {
// bug in multiaddr
panic(err)
}
ports[protocol][int(port)] = false
e := entry{protocol: protocol, port: int(port)}
if _, ok := nmgr.tracked[e]; ok {
nmgr.tracked[e] = true
} else {
newAddresses = append(newAddresses, e)
}
}
var wg sync.WaitGroup
defer wg.Wait()
// Close old mappings
for _, m := range nmgr.nat.Mappings() {
mappedPort := m.InternalPort()
if _, ok := ports[m.Protocol()][mappedPort]; !ok {
// No longer need this mapping.
wg.Add(1)
go func(m inat.Mapping) {
defer wg.Done()
m.Close()
}(m)
} else {
// already mapped
ports[m.Protocol()][mappedPort] = true
for e, v := range nmgr.tracked {
if !v {
nmgr.nat.RemoveMapping(e.protocol, e.port)
delete(nmgr.tracked, e)
}
}
// Create new mappings.
for proto, pports := range ports {
for port, mapped := range pports {
if mapped {
continue
}
wg.Add(1)
go func(proto string, port int) {
defer wg.Done()
_, err := nmgr.nat.NewMapping(proto, port)
if err != nil {
log.Errorf("failed to port-map %s port %d: %s", proto, port, err)
}
}(proto, port)
for _, e := range newAddresses {
if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil {
log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err)
}
nmgr.tracked[e] = false
}
}
// NAT returns the natManager's nat object. this may be nil, if
// (a) the search process is still ongoing, or (b) the search process
// found no nat. Clients must check whether the return value is nil.
func (nmgr *natManager) NAT() *inat.NAT {
func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
nmgr.natMx.Lock()
defer nmgr.natMx.Unlock()
return nmgr.nat
}
type nmgrNetNotifiee natManager
if nmgr.nat == nil { // NAT not yet initialized
return nil
}
func (nn *nmgrNetNotifiee) natManager() *natManager {
return (*natManager)(nn)
}
var found bool
var proto int // ma.P_TCP or ma.P_UDP
transport, rest := ma.SplitFunc(addr, func(c ma.Component) bool {
if found {
return true
}
proto = c.Protocol().Code
found = proto == ma.P_TCP || proto == ma.P_UDP
return false
})
if !manet.IsThinWaist(transport) {
return nil
}
func (nn *nmgrNetNotifiee) Listen(n network.Network, addr ma.Multiaddr) {
nn.natManager().sync()
}
naddr, err := manet.ToNetAddr(transport)
if err != nil {
log.Error("error parsing net multiaddr %q: %s", transport, err)
return nil
}
func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) {
nn.natManager().sync()
var (
ip net.IP
port int
protocol string
)
switch naddr := naddr.(type) {
case *net.TCPAddr:
ip = naddr.IP
port = naddr.Port
protocol = "tcp"
case *net.UDPAddr:
ip = naddr.IP
port = naddr.Port
protocol = "udp"
default:
return nil
}
if !ip.IsGlobalUnicast() && !ip.IsUnspecified() {
// We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc.
return nil
}
extAddr, ok := nmgr.nat.GetMapping(protocol, port)
if !ok {
return nil
}
var mappedAddr net.Addr
switch naddr.(type) {
case *net.TCPAddr:
mappedAddr = net.TCPAddrFromAddrPort(extAddr)
case *net.UDPAddr:
mappedAddr = net.UDPAddrFromAddrPort(extAddr)
}
mappedMaddr, err := manet.FromNetAddr(mappedAddr)
if err != nil {
log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err)
return nil
}
extMaddr := mappedMaddr
if rest != nil {
extMaddr = ma.Join(extMaddr, rest)
}
return extMaddr
}
func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {}
func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {}
type nmgrNetNotifiee natManager
func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) }
func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() }
func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { nn.natManager().sync() }
func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {}
func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {}

108
p2p/host/basic/natmgr_test.go

@ -0,0 +1,108 @@
package basichost
import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
ma "github.com/multiformats/go-multiaddr"
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
"github.com/golang/mock/gomock"
)
func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) {
t.Helper()
ctrl := gomock.NewController(t)
mockNAT = NewMockNAT(ctrl)
origDiscoverNAT := discoverNAT
discoverNAT = func(ctx context.Context) (nat, error) { return mockNAT, nil }
return mockNAT, func() {
discoverNAT = origDiscoverNAT
ctrl.Finish()
}
}
func TestMapping(t *testing.T) {
mockNAT, reset := setupMockNAT(t)
defer reset()
sw := swarmt.GenSwarm(t)
defer sw.Close()
m := newNATManager(sw)
require.Eventually(t, func() bool {
m.natMx.Lock()
defer m.natMx.Unlock()
return m.nat != nil
}, time.Second, time.Millisecond)
externalAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 4321)
// pretend that we have a TCP mapping
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true)
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234")))
// pretend that we have a QUIC mapping
mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true)
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1")))
// pretend that there's no mapping
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(netip.AddrPort{}, false)
require.Nil(t, m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234")))
// make sure this works for WebSocket addresses as well
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true)
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321/ws"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234/ws")))
// make sure this works for WebTransport addresses as well
mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true)
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1/webtransport"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1/webtransport")))
}
func TestAddAndRemoveListeners(t *testing.T) {
mockNAT, reset := setupMockNAT(t)
defer reset()
sw := swarmt.GenSwarm(t)
defer sw.Close()
m := newNATManager(sw)
require.Eventually(t, func() bool {
m.natMx.Lock()
defer m.natMx.Unlock()
return m.nat != nil
}, time.Second, time.Millisecond)
added := make(chan struct{}, 1)
// add a TCP listener
mockNAT.EXPECT().AddMapping("tcp", 1234).Do(func(string, int) { added <- struct{}{} })
require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/tcp/1234")))
select {
case <-added:
case <-time.After(time.Second):
t.Fatal("didn't receive call to AddMapping")
}
// add a QUIC listener
mockNAT.EXPECT().AddMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} })
require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1")))
select {
case <-added:
case <-time.After(time.Second):
t.Fatal("didn't receive call to AddMapping")
}
// remove the QUIC listener
mockNAT.EXPECT().RemoveMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} })
sw.ListenClose(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))
select {
case <-added:
case <-time.After(time.Second):
t.Fatal("didn't receive call to RemoveMapping")
}
// test shutdown
mockNAT.EXPECT().RemoveMapping("tcp", 1234).MaxTimes(1)
mockNAT.EXPECT().Close().MaxTimes(1)
}

119
p2p/net/nat/mapping.go

@ -1,119 +0,0 @@
package nat
import (
"fmt"
"net"
"sync"
"time"
)
// Mapping represents a port mapping in a NAT.
type Mapping interface {
// NAT returns the NAT object this Mapping belongs to.
NAT() *NAT
// Protocol returns the protocol of this port mapping. This is either
// "tcp" or "udp" as no other protocols are likely to be NAT-supported.
Protocol() string
// InternalPort returns the internal device port. Mapping will continue to
// try to map InternalPort() to an external facing port.
InternalPort() int
// ExternalPort returns the external facing port. If the mapping is not
// established, port will be 0
ExternalPort() int
// ExternalAddr returns the external facing address. If the mapping is not
// established, addr will be nil, and and ErrNoMapping will be returned.
ExternalAddr() (addr net.Addr, err error)
// Close closes the port mapping
Close() error
}
// keeps republishing
type mapping struct {
sync.Mutex // guards all fields
nat *NAT
proto string
intport int
extport int
cached net.IP
cacheTime time.Time
cacheLk sync.Mutex
}
func (m *mapping) NAT() *NAT {
m.Lock()
defer m.Unlock()
return m.nat
}
func (m *mapping) Protocol() string {
m.Lock()
defer m.Unlock()
return m.proto
}
func (m *mapping) InternalPort() int {
m.Lock()
defer m.Unlock()
return m.intport
}
func (m *mapping) ExternalPort() int {
m.Lock()
defer m.Unlock()
return m.extport
}
func (m *mapping) setExternalPort(p int) {
m.Lock()
defer m.Unlock()
m.extport = p
}
func (m *mapping) ExternalAddr() (net.Addr, error) {
m.cacheLk.Lock()
defer m.cacheLk.Unlock()
oport := m.ExternalPort()
if oport == 0 {
// dont even try right now.
return nil, ErrNoMapping
}
if time.Since(m.cacheTime) >= CacheTime {
m.nat.natmu.Lock()
cval, err := m.nat.nat.GetExternalAddress()
m.nat.natmu.Unlock()
if err != nil {
return nil, err
}
m.cached = cval
m.cacheTime = time.Now()
}
switch m.Protocol() {
case "tcp":
return &net.TCPAddr{
IP: m.cached,
Port: oport,
}, nil
case "udp":
return &net.UDPAddr{
IP: m.cached,
Port: oport,
}, nil
default:
panic(fmt.Sprintf("invalid protocol %q", m.Protocol()))
}
}
func (m *mapping) Close() error {
m.nat.removeMapping(m)
return nil
}

124
p2p/net/nat/mock_nat_test.go

@ -0,0 +1,124 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/libp2p/go-nat (interfaces: NAT)
// Package nat is a generated GoMock package.
package nat
import (
net "net"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockNAT is a mock of NAT interface.
type MockNAT struct {
ctrl *gomock.Controller
recorder *MockNATMockRecorder
}
// MockNATMockRecorder is the mock recorder for MockNAT.
type MockNATMockRecorder struct {
mock *MockNAT
}
// NewMockNAT creates a new mock instance.
func NewMockNAT(ctrl *gomock.Controller) *MockNAT {
mock := &MockNAT{ctrl: ctrl}
mock.recorder = &MockNATMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockNAT) EXPECT() *MockNATMockRecorder {
return m.recorder
}
// AddPortMapping mocks base method.
func (m *MockNAT) AddPortMapping(arg0 string, arg1 int, arg2 string, arg3 time.Duration) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPortMapping", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddPortMapping indicates an expected call of AddPortMapping.
func (mr *MockNATMockRecorder) AddPortMapping(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPortMapping", reflect.TypeOf((*MockNAT)(nil).AddPortMapping), arg0, arg1, arg2, arg3)
}
// DeletePortMapping mocks base method.
func (m *MockNAT) DeletePortMapping(arg0 string, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePortMapping", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeletePortMapping indicates an expected call of DeletePortMapping.
func (mr *MockNATMockRecorder) DeletePortMapping(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePortMapping", reflect.TypeOf((*MockNAT)(nil).DeletePortMapping), arg0, arg1)
}
// GetDeviceAddress mocks base method.
func (m *MockNAT) GetDeviceAddress() (net.IP, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDeviceAddress")
ret0, _ := ret[0].(net.IP)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetDeviceAddress indicates an expected call of GetDeviceAddress.
func (mr *MockNATMockRecorder) GetDeviceAddress() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceAddress", reflect.TypeOf((*MockNAT)(nil).GetDeviceAddress))
}
// GetExternalAddress mocks base method.
func (m *MockNAT) GetExternalAddress() (net.IP, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExternalAddress")
ret0, _ := ret[0].(net.IP)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetExternalAddress indicates an expected call of GetExternalAddress.
func (mr *MockNATMockRecorder) GetExternalAddress() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExternalAddress", reflect.TypeOf((*MockNAT)(nil).GetExternalAddress))
}
// GetInternalAddress mocks base method.
func (m *MockNAT) GetInternalAddress() (net.IP, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetInternalAddress")
ret0, _ := ret[0].(net.IP)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetInternalAddress indicates an expected call of GetInternalAddress.
func (mr *MockNATMockRecorder) GetInternalAddress() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInternalAddress", reflect.TypeOf((*MockNAT)(nil).GetInternalAddress))
}
// Type mocks base method.
func (m *MockNAT) Type() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Type")
ret0, _ := ret[0].(string)
return ret0
}
// Type indicates an expected call of Type.
func (mr *MockNATMockRecorder) Type() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockNAT)(nil).Type))
}

210
p2p/net/nat/nat.go

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/netip"
"sync"
"time"
@ -19,18 +20,30 @@ var log = logging.Logger("nat")
// MappingDuration is a default port mapping duration.
// Port mappings are renewed every (MappingDuration / 3)
const MappingDuration = time.Second * 60
const MappingDuration = time.Minute
// CacheTime is the time a mapping will cache an external address for
const CacheTime = time.Second * 15
const CacheTime = 15 * time.Second
// DiscoverNAT looks for a NAT device in the network and
// returns an object that can manage port mappings.
type entry struct {
protocol string
port int
}
// so we can mock it in tests
var discoverGateway = nat.DiscoverGateway
// DiscoverNAT looks for a NAT device in the network and returns an object that can manage port mappings.
func DiscoverNAT(ctx context.Context) (*NAT, error) {
natInstance, err := nat.DiscoverGateway(ctx)
natInstance, err := discoverGateway(ctx)
if err != nil {
return nil, err
}
var extAddr netip.Addr
extIP, err := natInstance.GetExternalAddress()
if err == nil {
extAddr, _ = netip.AddrFromSlice(extIP)
}
// Log the device addr.
addr, err := natInstance.GetDeviceAddress()
@ -40,7 +53,20 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) {
log.Debug("DiscoverGateway address:", addr)
}
return newNAT(natInstance), nil
ctx, cancel := context.WithCancel(context.Background())
nat := &NAT{
nat: natInstance,
extAddr: extAddr,
mappings: make(map[entry]int),
ctx: ctx,
ctxCancel: cancel,
}
nat.refCount.Add(1)
go func() {
defer nat.refCount.Done()
nat.background()
}()
return nat, nil
}
// NAT is an object that manages address port mappings in
@ -50,6 +76,8 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) {
type NAT struct {
natmu sync.Mutex
nat nat.NAT
// External IP of the NAT. Will be renewed periodically (every CacheTime).
extAddr netip.Addr
refCount sync.WaitGroup
ctx context.Context
@ -57,17 +85,7 @@ type NAT struct {
mappingmu sync.RWMutex // guards mappings
closed bool
mappings map[*mapping]struct{}
}
func newNAT(realNAT nat.NAT) *NAT {
ctx, cancel := context.WithCancel(context.Background())
return &NAT{
nat: realNAT,
mappings: make(map[*mapping]struct{}),
ctx: ctx,
ctxCancel: cancel,
}
mappings map[entry]int
}
// Close shuts down all port mappings. NAT can no longer be used.
@ -81,99 +99,139 @@ func (nat *NAT) Close() error {
return nil
}
// Mappings returns a slice of all NAT mappings
func (nat *NAT) Mappings() []Mapping {
func (nat *NAT) GetMapping(protocol string, port int) (addr netip.AddrPort, found bool) {
nat.mappingmu.Lock()
maps2 := make([]Mapping, 0, len(nat.mappings))
for m := range nat.mappings {
maps2 = append(maps2, m)
defer nat.mappingmu.Unlock()
if !nat.extAddr.IsValid() {
return netip.AddrPort{}, false
}
nat.mappingmu.Unlock()
return maps2
extPort, found := nat.mappings[entry{protocol: protocol, port: port}]
if !found {
return netip.AddrPort{}, false
}
return netip.AddrPortFrom(nat.extAddr, uint16(extPort)), true
}
// NewMapping attempts to construct a mapping on protocol and internal port
// It will also periodically renew the mapping until the returned Mapping
// -- or its parent NAT -- is Closed.
// AddMapping attempts to construct a mapping on protocol and internal port.
// It blocks until a mapping was established. Once added, it periodically renews the mapping.
//
// May not succeed, and mappings may change over time;
// NAT devices may not respect our port requests, and even lie.
// Clients should not store the mapped results, but rather always
// poll our object for the latest mappings.
func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) {
if nat == nil {
return nil, fmt.Errorf("no nat available")
}
func (nat *NAT) AddMapping(protocol string, port int) error {
switch protocol {
case "tcp", "udp":
default:
return nil, fmt.Errorf("invalid protocol: %s", protocol)
}
m := &mapping{
intport: port,
nat: nat,
proto: protocol,
return fmt.Errorf("invalid protocol: %s", protocol)
}
nat.mappingmu.Lock()
defer nat.mappingmu.Unlock()
if nat.closed {
nat.mappingmu.Unlock()
return nil, errors.New("closed")
return errors.New("closed")
}
nat.mappings[m] = struct{}{}
nat.refCount.Add(1)
nat.mappingmu.Unlock()
go nat.refreshMappings(m)
// do it once synchronously, so first mapping is done right away, and before exiting,
// allowing users -- in the optimistic case -- to use results right after.
nat.establishMapping(m)
return m, nil
extPort := nat.establishMapping(protocol, port)
nat.mappings[entry{protocol: protocol, port: port}] = extPort
return nil
}
func (nat *NAT) removeMapping(m *mapping) {
// RemoveMapping removes a port mapping.
// It blocks until the NAT has removed the mapping.
func (nat *NAT) RemoveMapping(protocol string, port int) error {
nat.mappingmu.Lock()
delete(nat.mappings, m)
nat.mappingmu.Unlock()
nat.natmu.Lock()
nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort())
nat.natmu.Unlock()
defer nat.mappingmu.Unlock()
switch protocol {
case "tcp", "udp":
e := entry{protocol: protocol, port: port}
if _, ok := nat.mappings[e]; ok {
delete(nat.mappings, e)
return nat.nat.DeletePortMapping(protocol, port)
}
return errors.New("unknown mapping")
default:
return fmt.Errorf("invalid protocol: %s", protocol)
}
}
func (nat *NAT) refreshMappings(m *mapping) {
defer nat.refCount.Done()
t := time.NewTicker(MappingDuration / 3)
func (nat *NAT) background() {
const mappingUpdate = MappingDuration / 3
now := time.Now()
nextMappingUpdate := now.Add(mappingUpdate)
nextAddrUpdate := now.Add(CacheTime)
t := time.NewTimer(minTime(nextMappingUpdate, nextAddrUpdate).Sub(now)) // don't use a ticker here. We don't know how long establishing the mappings takes.
defer t.Stop()
var in []entry
var out []int // port numbers
for {
select {
case <-t.C:
nat.establishMapping(m)
case now := <-t.C:
if now.After(nextMappingUpdate) {
in = in[:0]
out = out[:0]
nat.mappingmu.Lock()
for e := range nat.mappings {
in = append(in, e)
}
nat.mappingmu.Unlock()
// Establishing the mapping involves network requests.
// Don't hold the mutex, just save the ports.
for _, e := range in {
out = append(out, nat.establishMapping(e.protocol, e.port))
}
nat.mappingmu.Lock()
for i, p := range in {
if _, ok := nat.mappings[p]; !ok {
continue // entry might have been deleted
}
nat.mappings[p] = out[i]
}
nat.mappingmu.Unlock()
nextMappingUpdate = time.Now().Add(mappingUpdate)
}
if now.After(nextAddrUpdate) {
var extAddr netip.Addr
extIP, err := nat.nat.GetExternalAddress()
if err == nil {
extAddr, _ = netip.AddrFromSlice(extIP)
}
nat.extAddr = extAddr
nextAddrUpdate = time.Now().Add(CacheTime)
}
t.Reset(time.Until(minTime(nextAddrUpdate, nextMappingUpdate)))
case <-nat.ctx.Done():
m.Close()
nat.mappingmu.Lock()
for e := range nat.mappings {
delete(nat.mappings, e)
nat.nat.DeletePortMapping(e.protocol, e.port)
}
nat.mappingmu.Unlock()
return
}
}
}
func (nat *NAT) establishMapping(m *mapping) {
oldport := m.ExternalPort()
log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort())
func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) {
log.Debugf("Attempting port map: %s/%d", protocol, internalPort)
const comment = "libp2p"
nat.natmu.Lock()
newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration)
var err error
externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration)
if err != nil {
// Some hardware does not support mappings with timeout, so try that
newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0)
externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0)
}
nat.natmu.Unlock()
if err != nil || newport == 0 {
m.setExternalPort(0) // clear mapping
if err != nil || externalPort == 0 {
// TODO: log.Event
if err != nil {
log.Warnf("failed to establish port mapping: %s", err)
@ -182,12 +240,16 @@ func (nat *NAT) establishMapping(m *mapping) {
}
// we do not close if the mapping failed,
// because it may work again next time.
return
return 0
}
m.setExternalPort(newport)
log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol())
if oldport != 0 && newport != oldport {
log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport)
log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol)
return externalPort
}
func minTime(a, b time.Time) time.Time {
if a.Before(b) {
return a
}
return b
}

69
p2p/net/nat/nat_test.go

@ -0,0 +1,69 @@
package nat
import (
"context"
"errors"
"net"
"net/netip"
"testing"
"github.com/libp2p/go-nat"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
//go:generate sh -c "go run github.com/golang/mock/mockgen -package nat -destination mock_nat_test.go github.com/libp2p/go-nat NAT"
func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) {
t.Helper()
ctrl := gomock.NewController(t)
mockNAT = NewMockNAT(ctrl)
mockNAT.EXPECT().GetDeviceAddress().Return(nil, errors.New("nope")) // is only used for logging
origDiscoverGateway := discoverGateway
discoverGateway = func(ctx context.Context) (nat.NAT, error) { return mockNAT, nil }
return mockNAT, func() {
discoverGateway = origDiscoverGateway
ctrl.Finish()
}
}
func TestAddMapping(t *testing.T) {
mockNAT, reset := setupMockNAT(t)
defer reset()
mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil)
nat, err := DiscoverNAT(context.Background())
require.NoError(t, err)
mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil)
require.NoError(t, nat.AddMapping("tcp", 10000))
_, found := nat.GetMapping("tcp", 9999)
require.False(t, found, "didn't expect a port mapping for unmapped port")
_, found = nat.GetMapping("udp", 10000)
require.False(t, found, "didn't expect a port mapping for unmapped protocol")
mapped, found := nat.GetMapping("tcp", 10000)
require.True(t, found, "expected port mapping")
require.Equal(t, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 1234), mapped)
}
func TestRemoveMapping(t *testing.T) {
mockNAT, reset := setupMockNAT(t)
defer reset()
mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil)
nat, err := DiscoverNAT(context.Background())
require.NoError(t, err)
mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil)
require.NoError(t, nat.AddMapping("tcp", 10000))
_, found := nat.GetMapping("tcp", 10000)
require.True(t, found, "expected port mapping")
require.Error(t, nat.RemoveMapping("tcp", 9999), "expected error for unknown mapping")
mockNAT.EXPECT().DeletePortMapping("tcp", 10000)
require.NoError(t, nat.RemoveMapping("tcp", 10000))
_, found = nat.GetMapping("tcp", 10000)
require.False(t, found, "didn't expect port mapping for deleted mapping")
}
Loading…
Cancel
Save