Browse Source

webrtc: correctly report incoming packet address on muxed connection (#2586)

pull/2604/head
Sukun 1 year ago
committed by GitHub
parent
commit
c2124f76f6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 47
      p2p/transport/webrtc/udpmux/mux.go
  2. 266
      p2p/transport/webrtc/udpmux/mux_test.go
  3. 37
      p2p/transport/webrtc/udpmux/muxed_connection.go

47
p2p/transport/webrtc/udpmux/mux.go

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
logging "github.com/ipfs/go-log/v2"
@ -42,9 +43,15 @@ type UDPMux struct {
queue chan Candidate
mx sync.Mutex
mx sync.Mutex
// ufragMap allows us to multiplex incoming STUN packets based on ufrag
ufragMap map[ufragConnKey]*muxedConnection
addrMap map[string]*muxedConnection
// addrMap allows us to correctly direct incoming packets after the connection
// is established and ufrag isn't available on all packets
addrMap map[string]*muxedConnection
// ufragAddrMap allows cleaning up all addresses from the addrMap once the connection is closed
// During the ICE connectivity checks, the same ufrag might be used on multiple addresses.
ufragAddrMap map[ufragConnKey][]net.Addr
// the context controls the lifecycle of the mux
wg sync.WaitGroup
@ -57,12 +64,13 @@ var _ ice.UDPMux = &UDPMux{}
func NewUDPMux(socket net.PacketConn) *UDPMux {
ctx, cancel := context.WithCancel(context.Background())
mux := &UDPMux{
ctx: ctx,
cancel: cancel,
socket: socket,
ufragMap: make(map[ufragConnKey]*muxedConnection),
addrMap: make(map[string]*muxedConnection),
queue: make(chan Candidate, 32),
ctx: ctx,
cancel: cancel,
socket: socket,
ufragMap: make(map[ufragConnKey]*muxedConnection),
addrMap: make(map[string]*muxedConnection),
ufragAddrMap: make(map[ufragConnKey][]net.Addr),
queue: make(chan Candidate, 32),
}
return mux
@ -130,7 +138,11 @@ func (mux *UDPMux) readLoop() {
n, addr, err := mux.socket.ReadFrom(buf)
if err != nil {
log.Errorf("error reading from socket: %v", err)
if strings.Contains(err.Error(), "use of closed network connection") {
log.Debugf("readLoop exiting: socket %s closed", mux.socket.LocalAddr())
} else {
log.Errorf("error reading from socket %s: %v", mux.socket.LocalAddr(), err)
}
pool.Put(buf)
return
}
@ -157,7 +169,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
conn, ok := mux.addrMap[addr.String()]
mux.mx.Unlock()
if ok {
if err := conn.Push(buf); err != nil {
if err := conn.Push(buf, addr); err != nil {
log.Debugf("could not push packet: %v", err)
return false
}
@ -196,7 +208,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
}
}
if err := conn.Push(buf); err != nil {
if err := conn.Push(buf, addr); err != nil {
log.Debugf("could not push packet: %v", err)
return false
}
@ -250,9 +262,12 @@ func (mux *UDPMux) RemoveConnByUfrag(ufrag string) {
for _, isIPv6 := range [...]bool{true, false} {
key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}
if conn, ok := mux.ufragMap[key]; ok {
if _, ok := mux.ufragMap[key]; ok {
delete(mux.ufragMap, key)
delete(mux.addrMap, conn.RemoteAddr().String())
for _, addr := range mux.ufragAddrMap[key] {
delete(mux.addrMap, addr.String())
}
delete(mux.ufragAddrMap, key)
}
}
}
@ -264,12 +279,14 @@ func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, _ *UDPMux, addr ne
defer mux.mx.Unlock()
if conn, ok := mux.ufragMap[key]; ok {
mux.addrMap[addr.String()] = conn
mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr)
return false, conn
}
conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }, addr)
conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) })
mux.ufragMap[key] = conn
mux.addrMap[addr.String()] = conn
mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr)
return true, conn
}

266
p2p/transport/webrtc/udpmux/mux_test.go

@ -1,89 +1,227 @@
package udpmux
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/pion/stun"
"github.com/stretchr/testify/require"
)
var _ net.PacketConn = dummyPacketConn{}
type dummyPacketConn struct{}
// Close implements net.PacketConn
func (dummyPacketConn) Close() error {
return nil
}
// LocalAddr implements net.PacketConn
func (dummyPacketConn) LocalAddr() net.Addr {
return nil
}
// ReadFrom implements net.PacketConn
func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, &net.UDPAddr{}, nil
}
// SetDeadline implements net.PacketConn
func (dummyPacketConn) SetDeadline(t time.Time) error {
return nil
func getSTUNBindingRequest(ufrag string) *stun.Message {
msg := stun.New()
msg.SetType(stun.BindingRequest)
uattr := stun.RawAttribute{
Type: stun.AttrUsername,
Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections
}
uattr.AddTo(msg)
msg.Encode()
return msg
}
// SetReadDeadline implements net.PacketConn
func (dummyPacketConn) SetReadDeadline(t time.Time) error {
return nil
func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) {
t.Helper()
msg := getSTUNBindingRequest(ufrag)
_, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0])
require.NoError(t, err)
}
// SetWriteDeadline implements net.PacketConn
func (dummyPacketConn) SetWriteDeadline(t time.Time) error {
return nil
func newPacketConn(t *testing.T) net.PacketConn {
t.Helper()
udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
c, err := net.ListenUDP("udp", udpPort0)
require.NoError(t, err)
t.Cleanup(func() { c.Close() })
return c
}
// WriteTo implements net.PacketConn
func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, nil
func TestAccept(t *testing.T) {
c := newPacketConn(t)
defer c.Close()
m := NewUDPMux(c)
m.Start()
defer m.Close()
ufrags := []string{"a", "b", "c", "d"}
conns := make([]net.PacketConn, len(ufrags))
for i, ufrag := range ufrags {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
for i, ufrag := range ufrags {
c, err := m.Accept(context.Background())
require.NoError(t, err)
require.Equal(t, c.Ufrag, ufrag)
require.Equal(t, c.Addr, conns[i].LocalAddr())
}
for i, ufrag := range ufrags {
// should not be accepted
setupMapping(t, ufrag, conns[i], m)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := m.Accept(ctx)
require.Error(t, err)
// should not be accepted
cc := newPacketConn(t)
setupMapping(t, ufrag, cc, m)
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = m.Accept(ctx)
require.Error(t, err)
}
}
func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool {
m.mx.Lock()
_, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}]
m.mx.Unlock()
return ok
func TestGetConn(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()
ufrags := []string{"a", "b", "c", "d"}
conns := make([]net.PacketConn, len(ufrags))
for i, ufrag := range ufrags {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
for i, ufrag := range ufrags {
c, err := m.Accept(context.Background())
require.NoError(t, err)
require.Equal(t, c.Ufrag, ufrag)
require.Equal(t, c.Addr, conns[i].LocalAddr())
}
for i, ufrag := range ufrags {
c, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
msg := make([]byte, 100)
_, _, err = c.ReadFrom(msg)
require.NoError(t, err)
}
for i, ufrag := range ufrags {
cc := newPacketConn(t)
// setupMapping of cc to ufrags[0] and remove the stun binding request from the queue
setupMapping(t, ufrag, cc, m)
mc, err := m.GetConn(ufrag, cc.LocalAddr())
require.NoError(t, err)
msg := make([]byte, 100)
_, _, err = mc.ReadFrom(msg)
require.NoError(t, err)
// Write from new connection should provide the new address on ReadFrom
_, err = cc.WriteTo([]byte("test1"), c.LocalAddr())
require.NoError(t, err)
n, addr, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr, cc.LocalAddr())
require.Equal(t, string(msg[:n]), "test1")
// Write from original connection should provide the original address
_, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr())
require.NoError(t, err)
n, addr, err = mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr, conns[i].LocalAddr())
require.Equal(t, string(msg[:n]), "test2")
}
}
var (
addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234}
addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234}
)
func TestUDPMux_GetConn(t *testing.T) {
m := NewUDPMux(dummyPacketConn{})
require.False(t, hasConn(m, "test", false))
conn, err := m.GetConn("test", &addrV4)
func TestRemoveConnByUfrag(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()
// Map each ufrag to two addresses
ufrag := "a"
count := 10
conns := make([]net.PacketConn, count)
for i := 0; i < 10; i++ {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
mc, err := m.GetConn(ufrag, conns[0].LocalAddr())
require.NoError(t, err)
require.NotNil(t, conn)
require.False(t, hasConn(m, "test", true))
connv6, err := m.GetConn("test", &addrV6)
require.NoError(t, err)
require.NotNil(t, connv6)
require.NotEqual(t, conn, connv6)
}
func TestUDPMux_RemoveConnectionOnClose(t *testing.T) {
mux := NewUDPMux(dummyPacketConn{})
conn, err := mux.GetConn("test", &addrV4)
for i := 0; i < 10; i++ {
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
require.Equal(t, mc1, mc)
}
// Now remove the ufrag
m.RemoveConnByUfrag(ufrag)
// All connections should now be associated with b
ufrag = "b"
for i := 0; i < 10; i++ {
setupMapping(t, ufrag, conns[i], m)
}
mc, err = m.GetConn(ufrag, conns[0].LocalAddr())
require.NoError(t, err)
require.NotNil(t, conn)
require.True(t, hasConn(mux, "test", false))
err = conn.Close()
for i := 0; i < 10; i++ {
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
require.Equal(t, mc1, mc)
}
// Should be different even if the address is the same
mc1, err := m.GetConn("a", conns[0].LocalAddr())
require.NoError(t, err)
require.NotEqual(t, mc1, mc)
}
require.False(t, hasConn(mux, "test", false))
func TestMuxedConnection(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()
msgCount := 3
connCount := 3
ufrags := []string{"a", "b", "c"}
var mu sync.Mutex
addrUfragMap := make(map[string]string)
for _, ufrag := range ufrags {
go func(ufrag string) {
for i := 0; i < connCount; i++ {
cc := newPacketConn(t)
mu.Lock()
addrUfragMap[cc.LocalAddr().String()] = ufrag
mu.Unlock()
setupMapping(t, ufrag, cc, m)
for j := 0; j < msgCount; j++ {
cc.WriteTo([]byte(ufrag), c.LocalAddr())
}
}
}(ufrag)
}
for _, ufrag := range ufrags {
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant
require.NoError(t, err)
for i := 0; i < connCount; i++ {
msg := make([]byte, 100)
// Read the binding request
_, addr1, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addrUfragMap[addr1.String()], ufrag)
// Read individual msgs
for i := 0; i < msgCount; i++ {
n, addr2, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr2, addr1)
require.Equal(t, ufrag, string(msg[:n]))
}
delete(addrUfragMap, addr1.String())
}
}
require.Equal(t, len(addrUfragMap), 0)
}

37
p2p/transport/webrtc/udpmux/muxed_connection.go

@ -9,6 +9,11 @@ import (
pool "github.com/libp2p/go-buffer-pool"
)
type packet struct {
buf []byte
addr net.Addr
}
var _ net.PacketConn = &muxedConnection{}
const queueLen = 128
@ -21,48 +26,46 @@ type muxedConnection struct {
ctx context.Context
cancel context.CancelFunc
onClose func()
queue chan []byte
remote net.Addr
queue chan packet
mux *UDPMux
}
var _ net.PacketConn = &muxedConnection{}
func newMuxedConnection(mux *UDPMux, onClose func(), remote net.Addr) *muxedConnection {
func newMuxedConnection(mux *UDPMux, onClose func()) *muxedConnection {
ctx, cancel := context.WithCancel(mux.ctx)
return &muxedConnection{
ctx: ctx,
cancel: cancel,
queue: make(chan []byte, queueLen),
queue: make(chan packet, queueLen),
onClose: onClose,
remote: remote,
mux: mux,
}
}
func (c *muxedConnection) Push(buf []byte) error {
func (c *muxedConnection) Push(buf []byte, addr net.Addr) error {
select {
case <-c.ctx.Done():
return errors.New("closed")
default:
}
select {
case c.queue <- buf:
case c.queue <- packet{buf: buf, addr: addr}:
return nil
default:
return errors.New("queue full")
}
}
func (c *muxedConnection) ReadFrom(p []byte) (int, net.Addr, error) {
func (c *muxedConnection) ReadFrom(buf []byte) (int, net.Addr, error) {
select {
case buf := <-c.queue:
n := copy(p, buf) // This might discard parts of the packet, if p is too short
if n < len(buf) {
log.Debugf("short read, had %d, read %d", len(buf), n)
case p := <-c.queue:
n := copy(buf, p.buf) // This might discard parts of the packet, if p is too short
if n < len(p.buf) {
log.Debugf("short read, had %d, read %d", len(p.buf), n)
}
pool.Put(buf)
return n, c.remote, nil
pool.Put(p.buf)
return n, p.addr, nil
case <-c.ctx.Done():
return 0, nil, c.ctx.Err()
}
@ -83,15 +86,15 @@ func (c *muxedConnection) Close() error {
// drain the packet queue
for {
select {
case <-c.queue:
case p := <-c.queue:
pool.Put(p.buf)
default:
return nil
}
}
}
func (c *muxedConnection) LocalAddr() net.Addr { return c.mux.socket.LocalAddr() }
func (c *muxedConnection) RemoteAddr() net.Addr { return c.remote }
func (c *muxedConnection) LocalAddr() net.Addr { return c.mux.socket.LocalAddr() }
func (*muxedConnection) SetDeadline(t time.Time) error {
// no deadline is desired here

Loading…
Cancel
Save