diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index 67d48e79a..f75504bc5 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -8,54 +8,23 @@ import ( "io/ioutil" mrand "math/rand" "net" - "sync" "sync/atomic" "time" - "github.com/libp2p/go-libp2p-core/control" + gomock "github.com/golang/mock/gomock" ic "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" + n "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type mockGater struct { - lk sync.Mutex - acceptAll bool - blockedPeer peer.ID -} - -func (c *mockGater) InterceptAccept(addrs network.ConnMultiaddrs) bool { - c.lk.Lock() - defer c.lk.Unlock() - return c.acceptAll || !manet.IsIPLoopback(addrs.RemoteMultiaddr()) -} - -func (c *mockGater) InterceptPeerDial(p peer.ID) (allow bool) { - return true -} - -func (c *mockGater) InterceptAddrDial(peer.ID, ma.Multiaddr) (allow bool) { - return true -} - -func (c *mockGater) InterceptSecured(_ network.Direction, p peer.ID, _ network.ConnMultiaddrs) (allow bool) { - c.lk.Lock() - defer c.lk.Unlock() - return p != c.blockedPeer -} - -func (c *mockGater) InterceptUpgraded(network.Conn) (allow bool, reason control.DisconnectReason) { - return true, 0 -} - +//go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" var _ = Describe("Connection", func() { var ( serverKey, clientKey ic.PrivKey @@ -200,11 +169,11 @@ var _ = Describe("Connection", func() { }) It("gates accepted connections", func() { - testMA, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic") - Expect(err).ToNot(HaveOccurred()) - cg := &mockGater{} - Expect(cg.InterceptAccept(&connAddrs{rmAddr: testMA})).To(BeFalse()) - + cg := NewMockConnectionGater(mockCtrl) + var allow bool + cg.EXPECT().InterceptAccept(gomock.Any()).DoAndReturn(func(n.ConnMultiaddrs) bool { + return allow + }).AnyTimes() serverTransport, err := NewTransport(serverKey, nil, cg) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") @@ -220,10 +189,8 @@ var _ = Describe("Connection", func() { Expect(err.(net.Error).Timeout()).To(BeTrue()) // now allow the address and make sure the connection goes through + allow = true clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second - cg.lk.Lock() - cg.acceptAll = true - cg.lk.Unlock() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) conn.Close() @@ -235,7 +202,10 @@ var _ = Describe("Connection", func() { ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - cg := &mockGater{acceptAll: true, blockedPeer: serverID} + cg := NewMockConnectionGater(mockCtrl) + cg.EXPECT().InterceptAccept(gomock.Any()).Return(true).AnyTimes() + cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()) + clientTransport, err := NewTransport(clientKey, nil, cg) Expect(err).ToNot(HaveOccurred()) @@ -245,10 +215,8 @@ var _ = Describe("Connection", func() { Expect(err).To(HaveOccurred()) // now allow the peerId and make sure the connection goes through + cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) clientTransport.(*transport).clientConfig.HandshakeTimeout = 2 * time.Second - cg.lk.Lock() - cg.blockedPeer = "none" - cg.lk.Unlock() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) conn.Close() diff --git a/p2p/transport/quic/libp2pquic_suite_test.go b/p2p/transport/quic/libp2pquic_suite_test.go index 5905763c0..0415fed75 100644 --- a/p2p/transport/quic/libp2pquic_suite_test.go +++ b/p2p/transport/quic/libp2pquic_suite_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + gomock "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" @@ -27,6 +28,7 @@ var ( garbageCollectIntervalOrig time.Duration maxUnusedDurationOrig time.Duration origQuicConfig *quic.Config + mockCtrl *gomock.Controller ) func isGarbageCollectorRunning() bool { @@ -36,6 +38,8 @@ func isGarbageCollectorRunning() bool { } var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) + Expect(isGarbageCollectorRunning()).To(BeFalse()) garbageCollectIntervalOrig = garbageCollectInterval maxUnusedDurationOrig = maxUnusedDuration @@ -46,6 +50,8 @@ var _ = BeforeEach(func() { }) var _ = AfterEach(func() { + mockCtrl.Finish() + Eventually(isGarbageCollectorRunning).Should(BeFalse()) garbageCollectInterval = garbageCollectIntervalOrig maxUnusedDuration = maxUnusedDurationOrig diff --git a/p2p/transport/quic/mock_connection_gater_test.go b/p2p/transport/quic/mock_connection_gater_test.go new file mode 100644 index 000000000..899a0c6d5 --- /dev/null +++ b/p2p/transport/quic/mock_connection_gater_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater) + +// Package libp2pquic is a generated GoMock package. +package libp2pquic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + control "github.com/libp2p/go-libp2p-core/control" + network "github.com/libp2p/go-libp2p-core/network" + peer "github.com/libp2p/go-libp2p-core/peer" + multiaddr "github.com/multiformats/go-multiaddr" +) + +// MockConnectionGater is a mock of ConnectionGater interface +type MockConnectionGater struct { + ctrl *gomock.Controller + recorder *MockConnectionGaterMockRecorder +} + +// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater +type MockConnectionGaterMockRecorder struct { + mock *MockConnectionGater +} + +// NewMockConnectionGater creates a new mock instance +func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { + mock := &MockConnectionGater{ctrl: ctrl} + mock.recorder = &MockConnectionGaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { + return m.recorder +} + +// InterceptAccept mocks base method +func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAccept", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAccept indicates an expected call of InterceptAccept +func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) +} + +// InterceptAddrDial mocks base method +func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAddrDial indicates an expected call of InterceptAddrDial +func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) +} + +// InterceptPeerDial mocks base method +func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptPeerDial indicates an expected call of InterceptPeerDial +func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) +} + +// InterceptSecured mocks base method +func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptSecured indicates an expected call of InterceptSecured +func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) +} + +// InterceptUpgraded mocks base method +func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(control.DisconnectReason) + return ret0, ret1 +} + +// InterceptUpgraded indicates an expected call of InterceptUpgraded +func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) +}