mirror of https://github.com/libp2p/go-libp2p.git
Browse Source
* Add conn_limiter to limit number of conns per ip cidr * Handle the case where we want to call OpenConnection without an IP address * Delete key when count==0gammazero/previous-well-known-resource
Marco Munizaga
6 months ago
committed by
GitHub
4 changed files with 364 additions and 9 deletions
@ -0,0 +1,141 @@ |
|||
package rcmgr |
|||
|
|||
import ( |
|||
"net/netip" |
|||
"sync" |
|||
) |
|||
|
|||
type ConnLimitPerCIDR struct { |
|||
// How many leading 1 bits in the mask
|
|||
BitMask int |
|||
ConnCount int |
|||
} |
|||
|
|||
// 8 for now so that it matches the number of concurrent dials we may do
|
|||
// in swarm_dial.go. With future smart dialing work we should bring this
|
|||
// down
|
|||
var defaultMaxConcurrentConns = 8 |
|||
|
|||
var defaultIP4Limit = ConnLimitPerCIDR{ |
|||
ConnCount: defaultMaxConcurrentConns, |
|||
BitMask: 32, |
|||
} |
|||
var defaultIP6Limits = []ConnLimitPerCIDR{ |
|||
{ |
|||
ConnCount: defaultMaxConcurrentConns, |
|||
BitMask: 56, |
|||
}, |
|||
{ |
|||
ConnCount: 8 * defaultMaxConcurrentConns, |
|||
BitMask: 48, |
|||
}, |
|||
} |
|||
|
|||
func WithLimitPeersPerCIDR(ipv4 []ConnLimitPerCIDR, ipv6 []ConnLimitPerCIDR) Option { |
|||
return func(rm *resourceManager) error { |
|||
if ipv4 != nil { |
|||
rm.connLimiter.connLimitPerCIDRIP4 = ipv4 |
|||
} |
|||
if ipv6 != nil { |
|||
rm.connLimiter.connLimitPerCIDRIP6 = ipv6 |
|||
} |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
type connLimiter struct { |
|||
mu sync.Mutex |
|||
connLimitPerCIDRIP4 []ConnLimitPerCIDR |
|||
connLimitPerCIDRIP6 []ConnLimitPerCIDR |
|||
ip4connsPerLimit []map[string]int |
|||
ip6connsPerLimit []map[string]int |
|||
} |
|||
|
|||
func newConnLimiter() *connLimiter { |
|||
return &connLimiter{ |
|||
connLimitPerCIDRIP4: []ConnLimitPerCIDR{defaultIP4Limit}, |
|||
connLimitPerCIDRIP6: defaultIP6Limits, |
|||
} |
|||
} |
|||
|
|||
// addConn adds a connection for the given IP address. It returns true if the connection is allowed.
|
|||
func (cl *connLimiter) addConn(ip netip.Addr) bool { |
|||
cl.mu.Lock() |
|||
defer cl.mu.Unlock() |
|||
limits := cl.connLimitPerCIDRIP4 |
|||
countsPerLimit := cl.ip4connsPerLimit |
|||
isIP6 := ip.Is6() |
|||
if isIP6 { |
|||
limits = cl.connLimitPerCIDRIP6 |
|||
countsPerLimit = cl.ip6connsPerLimit |
|||
} |
|||
|
|||
if len(countsPerLimit) == 0 && len(limits) > 0 { |
|||
countsPerLimit = make([]map[string]int, len(limits)) |
|||
if isIP6 { |
|||
cl.ip6connsPerLimit = countsPerLimit |
|||
} else { |
|||
cl.ip4connsPerLimit = countsPerLimit |
|||
} |
|||
} |
|||
|
|||
for i, limit := range limits { |
|||
prefix, err := ip.Prefix(limit.BitMask) |
|||
if err != nil { |
|||
return false |
|||
} |
|||
masked := prefix.String() |
|||
|
|||
counts, ok := countsPerLimit[i][masked] |
|||
if !ok { |
|||
if countsPerLimit[i] == nil { |
|||
countsPerLimit[i] = make(map[string]int) |
|||
} |
|||
countsPerLimit[i][masked] = 0 |
|||
} |
|||
if counts+1 > limit.ConnCount { |
|||
return false |
|||
} |
|||
} |
|||
|
|||
// All limit checks passed, now we update the counts
|
|||
for i, limit := range limits { |
|||
prefix, _ := ip.Prefix(limit.BitMask) |
|||
masked := prefix.String() |
|||
countsPerLimit[i][masked]++ |
|||
} |
|||
|
|||
return true |
|||
} |
|||
|
|||
func (cl *connLimiter) rmConn(ip netip.Addr) { |
|||
cl.mu.Lock() |
|||
defer cl.mu.Unlock() |
|||
limits := cl.connLimitPerCIDRIP4 |
|||
countsPerLimit := cl.ip4connsPerLimit |
|||
isIP6 := ip.Is6() |
|||
if isIP6 { |
|||
limits = cl.connLimitPerCIDRIP6 |
|||
countsPerLimit = cl.ip6connsPerLimit |
|||
} |
|||
|
|||
for i, limit := range limits { |
|||
prefix, err := ip.Prefix(limit.BitMask) |
|||
if err != nil { |
|||
// Unexpected since we should have seen this IP before in addConn
|
|||
log.Errorf("unexpected error getting prefix: %v", err) |
|||
continue |
|||
} |
|||
masked := prefix.String() |
|||
counts, ok := countsPerLimit[i][masked] |
|||
if !ok || counts == 0 { |
|||
// Unexpected, but don't panic
|
|||
log.Errorf("unexpected conn count for %s ok=%v count=%v", masked, ok, counts) |
|||
continue |
|||
} |
|||
countsPerLimit[i][masked]-- |
|||
if countsPerLimit[i][masked] == 0 { |
|||
delete(countsPerLimit[i], masked) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,158 @@ |
|||
package rcmgr |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"fmt" |
|||
"net" |
|||
"net/netip" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func TestItLimits(t *testing.T) { |
|||
t.Run("IPv4", func(t *testing.T) { |
|||
ip, err := netip.ParseAddr("1.2.3.4") |
|||
require.NoError(t, err) |
|||
cl := newConnLimiter() |
|||
cl.connLimitPerCIDRIP4[0].ConnCount = 1 |
|||
require.True(t, cl.addConn(ip)) |
|||
|
|||
// should fail the second time
|
|||
require.False(t, cl.addConn(ip)) |
|||
|
|||
otherIP, err := netip.ParseAddr("1.2.3.5") |
|||
require.NoError(t, err) |
|||
require.True(t, cl.addConn(otherIP)) |
|||
}) |
|||
t.Run("IPv6", func(t *testing.T) { |
|||
ip, err := netip.ParseAddr("1:2:3:4::1") |
|||
require.NoError(t, err) |
|||
cl := newConnLimiter() |
|||
original := cl.connLimitPerCIDRIP6[0].ConnCount |
|||
cl.connLimitPerCIDRIP6[0].ConnCount = 1 |
|||
defer func() { |
|||
cl.connLimitPerCIDRIP6[0].ConnCount = original |
|||
}() |
|||
require.True(t, cl.addConn(ip)) |
|||
|
|||
// should fail the second time
|
|||
require.False(t, cl.addConn(ip)) |
|||
otherIPSameSubnet := netip.MustParseAddr("1:2:3:4::2") |
|||
require.False(t, cl.addConn(otherIPSameSubnet)) |
|||
|
|||
otherIP := netip.MustParseAddr("2:2:3:4::2") |
|||
require.True(t, cl.addConn(otherIP)) |
|||
}) |
|||
|
|||
t.Run("IPv6 with multiple limits", func(t *testing.T) { |
|||
cl := newConnLimiter() |
|||
for i := 0; i < defaultMaxConcurrentConns; i++ { |
|||
ip := net.ParseIP("ff:2:3:4::1") |
|||
binary.BigEndian.PutUint16(ip[14:], uint16(i)) |
|||
ipAddr := netip.MustParseAddr(ip.String()) |
|||
require.True(t, cl.addConn(ipAddr)) |
|||
} |
|||
|
|||
// Next one should fail
|
|||
ip := net.ParseIP("ff:2:3:4::1") |
|||
binary.BigEndian.PutUint16(ip[14:], uint16(defaultMaxConcurrentConns+1)) |
|||
require.False(t, cl.addConn(netip.MustParseAddr(ip.String()))) |
|||
|
|||
// But on a different root subnet should work
|
|||
otherIP := netip.MustParseAddr("ffef:2:3::1") |
|||
require.True(t, cl.addConn(otherIP)) |
|||
|
|||
// But too many on the next subnet limit will fail too
|
|||
for i := 0; i < defaultMaxConcurrentConns*8; i++ { |
|||
ip := net.ParseIP("ffef:2:3:4::1") |
|||
binary.BigEndian.PutUint16(ip[5:7], uint16(i)) |
|||
fmt.Println(ip.String()) |
|||
ipAddr := netip.MustParseAddr(ip.String()) |
|||
require.True(t, cl.addConn(ipAddr)) |
|||
} |
|||
|
|||
ip = net.ParseIP("ffef:2:3:4::1") |
|||
binary.BigEndian.PutUint16(ip[5:7], uint16(defaultMaxConcurrentConns*8+1)) |
|||
ipAddr := netip.MustParseAddr(ip.String()) |
|||
require.False(t, cl.addConn(ipAddr)) |
|||
}) |
|||
} |
|||
|
|||
func genIP(data *[]byte) (netip.Addr, bool) { |
|||
if len(*data) < 1 { |
|||
return netip.Addr{}, false |
|||
} |
|||
|
|||
genIP6 := (*data)[0]&0x01 == 1 |
|||
bytesRequired := 4 |
|||
if genIP6 { |
|||
bytesRequired = 16 |
|||
} |
|||
|
|||
if len((*data)[1:]) < bytesRequired { |
|||
return netip.Addr{}, false |
|||
} |
|||
|
|||
*data = (*data)[1:] |
|||
ip, ok := netip.AddrFromSlice((*data)[:bytesRequired]) |
|||
*data = (*data)[bytesRequired:] |
|||
return ip, ok |
|||
} |
|||
|
|||
func FuzzConnLimiter(f *testing.F) { |
|||
// The goal is to try to enter a state where the count is incorrectly 0
|
|||
f.Fuzz(func(t *testing.T, data []byte) { |
|||
ips := make([]netip.Addr, 0, len(data)/5) |
|||
for { |
|||
ip, ok := genIP(&data) |
|||
if !ok { |
|||
break |
|||
} |
|||
ips = append(ips, ip) |
|||
} |
|||
|
|||
cl := newConnLimiter() |
|||
addedConns := make([]netip.Addr, 0, len(ips)) |
|||
for _, ip := range ips { |
|||
if cl.addConn(ip) { |
|||
addedConns = append(addedConns, ip) |
|||
} |
|||
} |
|||
|
|||
addedCount := 0 |
|||
for _, ip := range cl.ip4connsPerLimit { |
|||
for _, count := range ip { |
|||
addedCount += count |
|||
} |
|||
} |
|||
for _, ip := range cl.ip6connsPerLimit { |
|||
for _, count := range ip { |
|||
addedCount += count |
|||
} |
|||
} |
|||
if addedCount == 0 && len(addedConns) > 0 { |
|||
t.Fatalf("added count: %d", addedCount) |
|||
} |
|||
|
|||
for _, ip := range addedConns { |
|||
cl.rmConn(ip) |
|||
} |
|||
|
|||
leftoverCount := 0 |
|||
for _, ip := range cl.ip4connsPerLimit { |
|||
for _, count := range ip { |
|||
leftoverCount += count |
|||
} |
|||
} |
|||
for _, ip := range cl.ip6connsPerLimit { |
|||
for _, count := range ip { |
|||
leftoverCount += count |
|||
} |
|||
} |
|||
if leftoverCount != 0 { |
|||
t.Fatalf("leftover count: %d", leftoverCount) |
|||
} |
|||
}) |
|||
|
|||
} |
Loading…
Reference in new issue