mirror of https://github.com/libp2p/go-libp2p.git
Sukun
1 year ago
committed by
GitHub
3 changed files with 254 additions and 96 deletions
@ -1,89 +1,227 @@ |
|||
package udpmux |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"net" |
|||
"sync" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/pion/stun" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
var _ net.PacketConn = dummyPacketConn{} |
|||
|
|||
type dummyPacketConn struct{} |
|||
|
|||
// Close implements net.PacketConn
|
|||
func (dummyPacketConn) Close() error { |
|||
return nil |
|||
} |
|||
|
|||
// LocalAddr implements net.PacketConn
|
|||
func (dummyPacketConn) LocalAddr() net.Addr { |
|||
return nil |
|||
} |
|||
|
|||
// ReadFrom implements net.PacketConn
|
|||
func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { |
|||
return 0, &net.UDPAddr{}, nil |
|||
} |
|||
|
|||
// SetDeadline implements net.PacketConn
|
|||
func (dummyPacketConn) SetDeadline(t time.Time) error { |
|||
return nil |
|||
func getSTUNBindingRequest(ufrag string) *stun.Message { |
|||
msg := stun.New() |
|||
msg.SetType(stun.BindingRequest) |
|||
uattr := stun.RawAttribute{ |
|||
Type: stun.AttrUsername, |
|||
Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections
|
|||
} |
|||
uattr.AddTo(msg) |
|||
msg.Encode() |
|||
return msg |
|||
} |
|||
|
|||
// SetReadDeadline implements net.PacketConn
|
|||
func (dummyPacketConn) SetReadDeadline(t time.Time) error { |
|||
return nil |
|||
func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) { |
|||
t.Helper() |
|||
msg := getSTUNBindingRequest(ufrag) |
|||
_, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0]) |
|||
require.NoError(t, err) |
|||
} |
|||
|
|||
// SetWriteDeadline implements net.PacketConn
|
|||
func (dummyPacketConn) SetWriteDeadline(t time.Time) error { |
|||
return nil |
|||
func newPacketConn(t *testing.T) net.PacketConn { |
|||
t.Helper() |
|||
udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} |
|||
c, err := net.ListenUDP("udp", udpPort0) |
|||
require.NoError(t, err) |
|||
t.Cleanup(func() { c.Close() }) |
|||
return c |
|||
} |
|||
|
|||
// WriteTo implements net.PacketConn
|
|||
func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { |
|||
return 0, nil |
|||
func TestAccept(t *testing.T) { |
|||
c := newPacketConn(t) |
|||
defer c.Close() |
|||
m := NewUDPMux(c) |
|||
m.Start() |
|||
defer m.Close() |
|||
|
|||
ufrags := []string{"a", "b", "c", "d"} |
|||
conns := make([]net.PacketConn, len(ufrags)) |
|||
for i, ufrag := range ufrags { |
|||
conns[i] = newPacketConn(t) |
|||
setupMapping(t, ufrag, conns[i], m) |
|||
} |
|||
for i, ufrag := range ufrags { |
|||
c, err := m.Accept(context.Background()) |
|||
require.NoError(t, err) |
|||
require.Equal(t, c.Ufrag, ufrag) |
|||
require.Equal(t, c.Addr, conns[i].LocalAddr()) |
|||
} |
|||
|
|||
for i, ufrag := range ufrags { |
|||
// should not be accepted
|
|||
setupMapping(t, ufrag, conns[i], m) |
|||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) |
|||
defer cancel() |
|||
_, err := m.Accept(ctx) |
|||
require.Error(t, err) |
|||
|
|||
// should not be accepted
|
|||
cc := newPacketConn(t) |
|||
setupMapping(t, ufrag, cc, m) |
|||
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) |
|||
defer cancel() |
|||
_, err = m.Accept(ctx) |
|||
require.Error(t, err) |
|||
} |
|||
} |
|||
|
|||
func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool { |
|||
m.mx.Lock() |
|||
_, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}] |
|||
m.mx.Unlock() |
|||
return ok |
|||
func TestGetConn(t *testing.T) { |
|||
c := newPacketConn(t) |
|||
m := NewUDPMux(c) |
|||
m.Start() |
|||
defer m.Close() |
|||
|
|||
ufrags := []string{"a", "b", "c", "d"} |
|||
conns := make([]net.PacketConn, len(ufrags)) |
|||
for i, ufrag := range ufrags { |
|||
conns[i] = newPacketConn(t) |
|||
setupMapping(t, ufrag, conns[i], m) |
|||
} |
|||
for i, ufrag := range ufrags { |
|||
c, err := m.Accept(context.Background()) |
|||
require.NoError(t, err) |
|||
require.Equal(t, c.Ufrag, ufrag) |
|||
require.Equal(t, c.Addr, conns[i].LocalAddr()) |
|||
} |
|||
|
|||
for i, ufrag := range ufrags { |
|||
c, err := m.GetConn(ufrag, conns[i].LocalAddr()) |
|||
require.NoError(t, err) |
|||
msg := make([]byte, 100) |
|||
_, _, err = c.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
} |
|||
|
|||
for i, ufrag := range ufrags { |
|||
cc := newPacketConn(t) |
|||
// setupMapping of cc to ufrags[0] and remove the stun binding request from the queue
|
|||
setupMapping(t, ufrag, cc, m) |
|||
mc, err := m.GetConn(ufrag, cc.LocalAddr()) |
|||
require.NoError(t, err) |
|||
msg := make([]byte, 100) |
|||
_, _, err = mc.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
|
|||
// Write from new connection should provide the new address on ReadFrom
|
|||
_, err = cc.WriteTo([]byte("test1"), c.LocalAddr()) |
|||
require.NoError(t, err) |
|||
n, addr, err := mc.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
require.Equal(t, addr, cc.LocalAddr()) |
|||
require.Equal(t, string(msg[:n]), "test1") |
|||
|
|||
// Write from original connection should provide the original address
|
|||
_, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr()) |
|||
require.NoError(t, err) |
|||
n, addr, err = mc.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
require.Equal(t, addr, conns[i].LocalAddr()) |
|||
require.Equal(t, string(msg[:n]), "test2") |
|||
} |
|||
} |
|||
|
|||
var ( |
|||
addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234} |
|||
addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234} |
|||
) |
|||
|
|||
func TestUDPMux_GetConn(t *testing.T) { |
|||
m := NewUDPMux(dummyPacketConn{}) |
|||
require.False(t, hasConn(m, "test", false)) |
|||
conn, err := m.GetConn("test", &addrV4) |
|||
func TestRemoveConnByUfrag(t *testing.T) { |
|||
c := newPacketConn(t) |
|||
m := NewUDPMux(c) |
|||
m.Start() |
|||
defer m.Close() |
|||
|
|||
// Map each ufrag to two addresses
|
|||
ufrag := "a" |
|||
count := 10 |
|||
conns := make([]net.PacketConn, count) |
|||
for i := 0; i < 10; i++ { |
|||
conns[i] = newPacketConn(t) |
|||
setupMapping(t, ufrag, conns[i], m) |
|||
} |
|||
mc, err := m.GetConn(ufrag, conns[0].LocalAddr()) |
|||
require.NoError(t, err) |
|||
require.NotNil(t, conn) |
|||
|
|||
require.False(t, hasConn(m, "test", true)) |
|||
connv6, err := m.GetConn("test", &addrV6) |
|||
require.NoError(t, err) |
|||
require.NotNil(t, connv6) |
|||
|
|||
require.NotEqual(t, conn, connv6) |
|||
} |
|||
|
|||
func TestUDPMux_RemoveConnectionOnClose(t *testing.T) { |
|||
mux := NewUDPMux(dummyPacketConn{}) |
|||
conn, err := mux.GetConn("test", &addrV4) |
|||
for i := 0; i < 10; i++ { |
|||
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) |
|||
require.NoError(t, err) |
|||
require.Equal(t, mc1, mc) |
|||
} |
|||
|
|||
// Now remove the ufrag
|
|||
m.RemoveConnByUfrag(ufrag) |
|||
|
|||
// All connections should now be associated with b
|
|||
ufrag = "b" |
|||
for i := 0; i < 10; i++ { |
|||
setupMapping(t, ufrag, conns[i], m) |
|||
} |
|||
mc, err = m.GetConn(ufrag, conns[0].LocalAddr()) |
|||
require.NoError(t, err) |
|||
require.NotNil(t, conn) |
|||
|
|||
require.True(t, hasConn(mux, "test", false)) |
|||
|
|||
err = conn.Close() |
|||
for i := 0; i < 10; i++ { |
|||
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) |
|||
require.NoError(t, err) |
|||
require.Equal(t, mc1, mc) |
|||
} |
|||
|
|||
// Should be different even if the address is the same
|
|||
mc1, err := m.GetConn("a", conns[0].LocalAddr()) |
|||
require.NoError(t, err) |
|||
require.NotEqual(t, mc1, mc) |
|||
} |
|||
|
|||
require.False(t, hasConn(mux, "test", false)) |
|||
func TestMuxedConnection(t *testing.T) { |
|||
c := newPacketConn(t) |
|||
m := NewUDPMux(c) |
|||
m.Start() |
|||
defer m.Close() |
|||
|
|||
msgCount := 3 |
|||
connCount := 3 |
|||
|
|||
ufrags := []string{"a", "b", "c"} |
|||
var mu sync.Mutex |
|||
addrUfragMap := make(map[string]string) |
|||
for _, ufrag := range ufrags { |
|||
go func(ufrag string) { |
|||
for i := 0; i < connCount; i++ { |
|||
cc := newPacketConn(t) |
|||
mu.Lock() |
|||
addrUfragMap[cc.LocalAddr().String()] = ufrag |
|||
mu.Unlock() |
|||
setupMapping(t, ufrag, cc, m) |
|||
for j := 0; j < msgCount; j++ { |
|||
cc.WriteTo([]byte(ufrag), c.LocalAddr()) |
|||
} |
|||
} |
|||
}(ufrag) |
|||
} |
|||
|
|||
for _, ufrag := range ufrags { |
|||
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant
|
|||
require.NoError(t, err) |
|||
for i := 0; i < connCount; i++ { |
|||
msg := make([]byte, 100) |
|||
// Read the binding request
|
|||
_, addr1, err := mc.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
require.Equal(t, addrUfragMap[addr1.String()], ufrag) |
|||
// Read individual msgs
|
|||
for i := 0; i < msgCount; i++ { |
|||
n, addr2, err := mc.ReadFrom(msg) |
|||
require.NoError(t, err) |
|||
require.Equal(t, addr2, addr1) |
|||
require.Equal(t, ufrag, string(msg[:n])) |
|||
} |
|||
delete(addrUfragMap, addr1.String()) |
|||
} |
|||
} |
|||
require.Equal(t, len(addrUfragMap), 0) |
|||
} |
|||
|
Loading…
Reference in new issue