mirror of https://github.com/libp2p/go-libp2p.git
Marten Seemann
2 years ago
committed by
GitHub
9 changed files with 662 additions and 350 deletions
@ -0,0 +1,92 @@ |
|||
// Code generated by MockGen. DO NOT EDIT.
|
|||
// Source: github.com/libp2p/go-libp2p/p2p/host/basic (interfaces: NAT)
|
|||
|
|||
// Package basichost is a generated GoMock package.
|
|||
package basichost |
|||
|
|||
import ( |
|||
netip "net/netip" |
|||
reflect "reflect" |
|||
|
|||
gomock "github.com/golang/mock/gomock" |
|||
) |
|||
|
|||
// MockNAT is a mock of NAT interface.
|
|||
type MockNAT struct { |
|||
ctrl *gomock.Controller |
|||
recorder *MockNATMockRecorder |
|||
} |
|||
|
|||
// MockNATMockRecorder is the mock recorder for MockNAT.
|
|||
type MockNATMockRecorder struct { |
|||
mock *MockNAT |
|||
} |
|||
|
|||
// NewMockNAT creates a new mock instance.
|
|||
func NewMockNAT(ctrl *gomock.Controller) *MockNAT { |
|||
mock := &MockNAT{ctrl: ctrl} |
|||
mock.recorder = &MockNATMockRecorder{mock} |
|||
return mock |
|||
} |
|||
|
|||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|||
func (m *MockNAT) EXPECT() *MockNATMockRecorder { |
|||
return m.recorder |
|||
} |
|||
|
|||
// AddMapping mocks base method.
|
|||
func (m *MockNAT) AddMapping(arg0 string, arg1 int) error { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "AddMapping", arg0, arg1) |
|||
ret0, _ := ret[0].(error) |
|||
return ret0 |
|||
} |
|||
|
|||
// AddMapping indicates an expected call of AddMapping.
|
|||
func (mr *MockNATMockRecorder) AddMapping(arg0, arg1 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNAT)(nil).AddMapping), arg0, arg1) |
|||
} |
|||
|
|||
// Close mocks base method.
|
|||
func (m *MockNAT) Close() error { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "Close") |
|||
ret0, _ := ret[0].(error) |
|||
return ret0 |
|||
} |
|||
|
|||
// Close indicates an expected call of Close.
|
|||
func (mr *MockNATMockRecorder) Close() *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNAT)(nil).Close)) |
|||
} |
|||
|
|||
// GetMapping mocks base method.
|
|||
func (m *MockNAT) GetMapping(arg0 string, arg1 int) (netip.AddrPort, bool) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "GetMapping", arg0, arg1) |
|||
ret0, _ := ret[0].(netip.AddrPort) |
|||
ret1, _ := ret[1].(bool) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// GetMapping indicates an expected call of GetMapping.
|
|||
func (mr *MockNATMockRecorder) GetMapping(arg0, arg1 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNAT)(nil).GetMapping), arg0, arg1) |
|||
} |
|||
|
|||
// RemoveMapping mocks base method.
|
|||
func (m *MockNAT) RemoveMapping(arg0 string, arg1 int) error { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "RemoveMapping", arg0, arg1) |
|||
ret0, _ := ret[0].(error) |
|||
return ret0 |
|||
} |
|||
|
|||
// RemoveMapping indicates an expected call of RemoveMapping.
|
|||
func (mr *MockNATMockRecorder) RemoveMapping(arg0, arg1 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNAT)(nil).RemoveMapping), arg0, arg1) |
|||
} |
@ -0,0 +1,6 @@ |
|||
//go:build gomock || generate
|
|||
|
|||
package basichost |
|||
|
|||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT"
|
|||
type NAT nat |
@ -0,0 +1,108 @@ |
|||
package basichost |
|||
|
|||
import ( |
|||
"context" |
|||
"net/netip" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/require" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
|
|||
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" |
|||
|
|||
"github.com/golang/mock/gomock" |
|||
) |
|||
|
|||
func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { |
|||
t.Helper() |
|||
ctrl := gomock.NewController(t) |
|||
mockNAT = NewMockNAT(ctrl) |
|||
origDiscoverNAT := discoverNAT |
|||
discoverNAT = func(ctx context.Context) (nat, error) { return mockNAT, nil } |
|||
return mockNAT, func() { |
|||
discoverNAT = origDiscoverNAT |
|||
ctrl.Finish() |
|||
} |
|||
} |
|||
|
|||
func TestMapping(t *testing.T) { |
|||
mockNAT, reset := setupMockNAT(t) |
|||
defer reset() |
|||
|
|||
sw := swarmt.GenSwarm(t) |
|||
defer sw.Close() |
|||
m := newNATManager(sw) |
|||
require.Eventually(t, func() bool { |
|||
m.natMx.Lock() |
|||
defer m.natMx.Unlock() |
|||
return m.nat != nil |
|||
}, time.Second, time.Millisecond) |
|||
externalAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 4321) |
|||
// pretend that we have a TCP mapping
|
|||
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) |
|||
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) |
|||
|
|||
// pretend that we have a QUIC mapping
|
|||
mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) |
|||
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) |
|||
|
|||
// pretend that there's no mapping
|
|||
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(netip.AddrPort{}, false) |
|||
require.Nil(t, m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) |
|||
|
|||
// make sure this works for WebSocket addresses as well
|
|||
mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) |
|||
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321/ws"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234/ws"))) |
|||
|
|||
// make sure this works for WebTransport addresses as well
|
|||
mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) |
|||
require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1/webtransport"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1/webtransport"))) |
|||
} |
|||
|
|||
func TestAddAndRemoveListeners(t *testing.T) { |
|||
mockNAT, reset := setupMockNAT(t) |
|||
defer reset() |
|||
|
|||
sw := swarmt.GenSwarm(t) |
|||
defer sw.Close() |
|||
m := newNATManager(sw) |
|||
require.Eventually(t, func() bool { |
|||
m.natMx.Lock() |
|||
defer m.natMx.Unlock() |
|||
return m.nat != nil |
|||
}, time.Second, time.Millisecond) |
|||
|
|||
added := make(chan struct{}, 1) |
|||
// add a TCP listener
|
|||
mockNAT.EXPECT().AddMapping("tcp", 1234).Do(func(string, int) { added <- struct{}{} }) |
|||
require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) |
|||
select { |
|||
case <-added: |
|||
case <-time.After(time.Second): |
|||
t.Fatal("didn't receive call to AddMapping") |
|||
} |
|||
|
|||
// add a QUIC listener
|
|||
mockNAT.EXPECT().AddMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) |
|||
require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) |
|||
select { |
|||
case <-added: |
|||
case <-time.After(time.Second): |
|||
t.Fatal("didn't receive call to AddMapping") |
|||
} |
|||
|
|||
// remove the QUIC listener
|
|||
mockNAT.EXPECT().RemoveMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) |
|||
sw.ListenClose(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1")) |
|||
select { |
|||
case <-added: |
|||
case <-time.After(time.Second): |
|||
t.Fatal("didn't receive call to RemoveMapping") |
|||
} |
|||
|
|||
// test shutdown
|
|||
mockNAT.EXPECT().RemoveMapping("tcp", 1234).MaxTimes(1) |
|||
mockNAT.EXPECT().Close().MaxTimes(1) |
|||
} |
@ -1,119 +0,0 @@ |
|||
package nat |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// Mapping represents a port mapping in a NAT.
|
|||
type Mapping interface { |
|||
// NAT returns the NAT object this Mapping belongs to.
|
|||
NAT() *NAT |
|||
|
|||
// Protocol returns the protocol of this port mapping. This is either
|
|||
// "tcp" or "udp" as no other protocols are likely to be NAT-supported.
|
|||
Protocol() string |
|||
|
|||
// InternalPort returns the internal device port. Mapping will continue to
|
|||
// try to map InternalPort() to an external facing port.
|
|||
InternalPort() int |
|||
|
|||
// ExternalPort returns the external facing port. If the mapping is not
|
|||
// established, port will be 0
|
|||
ExternalPort() int |
|||
|
|||
// ExternalAddr returns the external facing address. If the mapping is not
|
|||
// established, addr will be nil, and and ErrNoMapping will be returned.
|
|||
ExternalAddr() (addr net.Addr, err error) |
|||
|
|||
// Close closes the port mapping
|
|||
Close() error |
|||
} |
|||
|
|||
// keeps republishing
|
|||
type mapping struct { |
|||
sync.Mutex // guards all fields
|
|||
|
|||
nat *NAT |
|||
proto string |
|||
intport int |
|||
extport int |
|||
|
|||
cached net.IP |
|||
cacheTime time.Time |
|||
cacheLk sync.Mutex |
|||
} |
|||
|
|||
func (m *mapping) NAT() *NAT { |
|||
m.Lock() |
|||
defer m.Unlock() |
|||
return m.nat |
|||
} |
|||
|
|||
func (m *mapping) Protocol() string { |
|||
m.Lock() |
|||
defer m.Unlock() |
|||
return m.proto |
|||
} |
|||
|
|||
func (m *mapping) InternalPort() int { |
|||
m.Lock() |
|||
defer m.Unlock() |
|||
return m.intport |
|||
} |
|||
|
|||
func (m *mapping) ExternalPort() int { |
|||
m.Lock() |
|||
defer m.Unlock() |
|||
return m.extport |
|||
} |
|||
|
|||
func (m *mapping) setExternalPort(p int) { |
|||
m.Lock() |
|||
defer m.Unlock() |
|||
m.extport = p |
|||
} |
|||
|
|||
func (m *mapping) ExternalAddr() (net.Addr, error) { |
|||
m.cacheLk.Lock() |
|||
defer m.cacheLk.Unlock() |
|||
oport := m.ExternalPort() |
|||
if oport == 0 { |
|||
// dont even try right now.
|
|||
return nil, ErrNoMapping |
|||
} |
|||
|
|||
if time.Since(m.cacheTime) >= CacheTime { |
|||
m.nat.natmu.Lock() |
|||
cval, err := m.nat.nat.GetExternalAddress() |
|||
m.nat.natmu.Unlock() |
|||
|
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
m.cached = cval |
|||
m.cacheTime = time.Now() |
|||
} |
|||
switch m.Protocol() { |
|||
case "tcp": |
|||
return &net.TCPAddr{ |
|||
IP: m.cached, |
|||
Port: oport, |
|||
}, nil |
|||
case "udp": |
|||
return &net.UDPAddr{ |
|||
IP: m.cached, |
|||
Port: oport, |
|||
}, nil |
|||
default: |
|||
panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) |
|||
} |
|||
} |
|||
|
|||
func (m *mapping) Close() error { |
|||
m.nat.removeMapping(m) |
|||
return nil |
|||
} |
@ -0,0 +1,124 @@ |
|||
// Code generated by MockGen. DO NOT EDIT.
|
|||
// Source: github.com/libp2p/go-nat (interfaces: NAT)
|
|||
|
|||
// Package nat is a generated GoMock package.
|
|||
package nat |
|||
|
|||
import ( |
|||
net "net" |
|||
reflect "reflect" |
|||
time "time" |
|||
|
|||
gomock "github.com/golang/mock/gomock" |
|||
) |
|||
|
|||
// MockNAT is a mock of NAT interface.
|
|||
type MockNAT struct { |
|||
ctrl *gomock.Controller |
|||
recorder *MockNATMockRecorder |
|||
} |
|||
|
|||
// MockNATMockRecorder is the mock recorder for MockNAT.
|
|||
type MockNATMockRecorder struct { |
|||
mock *MockNAT |
|||
} |
|||
|
|||
// NewMockNAT creates a new mock instance.
|
|||
func NewMockNAT(ctrl *gomock.Controller) *MockNAT { |
|||
mock := &MockNAT{ctrl: ctrl} |
|||
mock.recorder = &MockNATMockRecorder{mock} |
|||
return mock |
|||
} |
|||
|
|||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|||
func (m *MockNAT) EXPECT() *MockNATMockRecorder { |
|||
return m.recorder |
|||
} |
|||
|
|||
// AddPortMapping mocks base method.
|
|||
func (m *MockNAT) AddPortMapping(arg0 string, arg1 int, arg2 string, arg3 time.Duration) (int, error) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "AddPortMapping", arg0, arg1, arg2, arg3) |
|||
ret0, _ := ret[0].(int) |
|||
ret1, _ := ret[1].(error) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// AddPortMapping indicates an expected call of AddPortMapping.
|
|||
func (mr *MockNATMockRecorder) AddPortMapping(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPortMapping", reflect.TypeOf((*MockNAT)(nil).AddPortMapping), arg0, arg1, arg2, arg3) |
|||
} |
|||
|
|||
// DeletePortMapping mocks base method.
|
|||
func (m *MockNAT) DeletePortMapping(arg0 string, arg1 int) error { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "DeletePortMapping", arg0, arg1) |
|||
ret0, _ := ret[0].(error) |
|||
return ret0 |
|||
} |
|||
|
|||
// DeletePortMapping indicates an expected call of DeletePortMapping.
|
|||
func (mr *MockNATMockRecorder) DeletePortMapping(arg0, arg1 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePortMapping", reflect.TypeOf((*MockNAT)(nil).DeletePortMapping), arg0, arg1) |
|||
} |
|||
|
|||
// GetDeviceAddress mocks base method.
|
|||
func (m *MockNAT) GetDeviceAddress() (net.IP, error) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "GetDeviceAddress") |
|||
ret0, _ := ret[0].(net.IP) |
|||
ret1, _ := ret[1].(error) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// GetDeviceAddress indicates an expected call of GetDeviceAddress.
|
|||
func (mr *MockNATMockRecorder) GetDeviceAddress() *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceAddress", reflect.TypeOf((*MockNAT)(nil).GetDeviceAddress)) |
|||
} |
|||
|
|||
// GetExternalAddress mocks base method.
|
|||
func (m *MockNAT) GetExternalAddress() (net.IP, error) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "GetExternalAddress") |
|||
ret0, _ := ret[0].(net.IP) |
|||
ret1, _ := ret[1].(error) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// GetExternalAddress indicates an expected call of GetExternalAddress.
|
|||
func (mr *MockNATMockRecorder) GetExternalAddress() *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExternalAddress", reflect.TypeOf((*MockNAT)(nil).GetExternalAddress)) |
|||
} |
|||
|
|||
// GetInternalAddress mocks base method.
|
|||
func (m *MockNAT) GetInternalAddress() (net.IP, error) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "GetInternalAddress") |
|||
ret0, _ := ret[0].(net.IP) |
|||
ret1, _ := ret[1].(error) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// GetInternalAddress indicates an expected call of GetInternalAddress.
|
|||
func (mr *MockNATMockRecorder) GetInternalAddress() *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInternalAddress", reflect.TypeOf((*MockNAT)(nil).GetInternalAddress)) |
|||
} |
|||
|
|||
// Type mocks base method.
|
|||
func (m *MockNAT) Type() string { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "Type") |
|||
ret0, _ := ret[0].(string) |
|||
return ret0 |
|||
} |
|||
|
|||
// Type indicates an expected call of Type.
|
|||
func (mr *MockNATMockRecorder) Type() *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockNAT)(nil).Type)) |
|||
} |
@ -0,0 +1,69 @@ |
|||
package nat |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"net" |
|||
"net/netip" |
|||
"testing" |
|||
|
|||
"github.com/libp2p/go-nat" |
|||
|
|||
"github.com/golang/mock/gomock" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
//go:generate sh -c "go run github.com/golang/mock/mockgen -package nat -destination mock_nat_test.go github.com/libp2p/go-nat NAT"
|
|||
|
|||
func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { |
|||
t.Helper() |
|||
ctrl := gomock.NewController(t) |
|||
mockNAT = NewMockNAT(ctrl) |
|||
mockNAT.EXPECT().GetDeviceAddress().Return(nil, errors.New("nope")) // is only used for logging
|
|||
origDiscoverGateway := discoverGateway |
|||
discoverGateway = func(ctx context.Context) (nat.NAT, error) { return mockNAT, nil } |
|||
return mockNAT, func() { |
|||
discoverGateway = origDiscoverGateway |
|||
ctrl.Finish() |
|||
} |
|||
} |
|||
|
|||
func TestAddMapping(t *testing.T) { |
|||
mockNAT, reset := setupMockNAT(t) |
|||
defer reset() |
|||
|
|||
mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) |
|||
nat, err := DiscoverNAT(context.Background()) |
|||
require.NoError(t, err) |
|||
|
|||
mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) |
|||
require.NoError(t, nat.AddMapping("tcp", 10000)) |
|||
|
|||
_, found := nat.GetMapping("tcp", 9999) |
|||
require.False(t, found, "didn't expect a port mapping for unmapped port") |
|||
_, found = nat.GetMapping("udp", 10000) |
|||
require.False(t, found, "didn't expect a port mapping for unmapped protocol") |
|||
mapped, found := nat.GetMapping("tcp", 10000) |
|||
require.True(t, found, "expected port mapping") |
|||
require.Equal(t, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 1234), mapped) |
|||
} |
|||
|
|||
func TestRemoveMapping(t *testing.T) { |
|||
mockNAT, reset := setupMockNAT(t) |
|||
defer reset() |
|||
|
|||
mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) |
|||
nat, err := DiscoverNAT(context.Background()) |
|||
require.NoError(t, err) |
|||
mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) |
|||
require.NoError(t, nat.AddMapping("tcp", 10000)) |
|||
_, found := nat.GetMapping("tcp", 10000) |
|||
require.True(t, found, "expected port mapping") |
|||
|
|||
require.Error(t, nat.RemoveMapping("tcp", 9999), "expected error for unknown mapping") |
|||
mockNAT.EXPECT().DeletePortMapping("tcp", 10000) |
|||
require.NoError(t, nat.RemoveMapping("tcp", 10000)) |
|||
|
|||
_, found = nat.GetMapping("tcp", 10000) |
|||
require.False(t, found, "didn't expect port mapping for deleted mapping") |
|||
} |
Loading…
Reference in new issue