mirror of https://github.com/libp2p/go-libp2p.git
Marten Seemann
3 years ago
10 changed files with 2494 additions and 0 deletions
@@ -0,0 +1,54 @@
package connmgr

import (
	"math/rand"
	"sync"
	"testing"

	"github.com/libp2p/go-libp2p-core/network"

	"github.com/stretchr/testify/require"
)

func randomConns(tb testing.TB) (c [5000]network.Conn) {
	for i := range c {
		c[i] = randConn(tb, nil)
	}
	return c
}

func BenchmarkLockContention(b *testing.B) {
	conns := randomConns(b)
	cm, err := NewConnManager(1000, 1000, WithGracePeriod(0))
	require.NoError(b, err)
	not := cm.Notifee()

	kill := make(chan struct{})
	var wg sync.WaitGroup

	for i := 0; i < 16; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case <-kill:
					return
				default:
					cm.TagPeer(conns[rand.Intn(len(conns))].RemotePeer(), "another-tag", 1)
				}
			}
		}()
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rc := conns[rand.Intn(len(conns))]
		not.Connected(nil, rc)
		cm.TagPeer(rc.RemotePeer(), "tag", 100)
		cm.UntagPeer(rc.RemotePeer(), "tag")
		not.Disconnected(nil, rc)
	}
	close(kill)
	wg.Wait()
}
@@ -0,0 +1,694 @@
package connmgr

import (
	"context"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/libp2p/go-libp2p-core/connmgr"
	"github.com/libp2p/go-libp2p-core/network"
	"github.com/libp2p/go-libp2p-core/peer"

	logging "github.com/ipfs/go-log/v2"
	ma "github.com/multiformats/go-multiaddr"
)

var log = logging.Logger("connmgr")

// BasicConnMgr is a ConnManager that trims connections whenever the count exceeds the
// high watermark. New connections are given a grace period before they're subject
// to trimming. Trims are automatically run on demand, but only if at least 10 seconds
// have elapsed since the previous trim. Furthermore, trims can be explicitly
// requested through the public interface of this struct (see TrimOpenConns).
//
// See configuration parameters in NewConnManager.
type BasicConnMgr struct {
	*decayer

	cfg      *config
	segments segments

	plk       sync.RWMutex
	protected map[peer.ID]map[string]struct{}

	// trimMutex ensures that only a single trim is in progress at a time.
	trimMutex sync.Mutex
	connCount int32
	// to be accessed atomically. This is mimicking the implementation of a sync.Once.
	// Take care of correct alignment when modifying this struct.
	trimCount uint64

	lastTrimMu sync.RWMutex
	lastTrim   time.Time

	refCount                sync.WaitGroup
	ctx                     context.Context
	cancel                  func()
	unregisterMemoryWatcher func()
}

var (
	_ connmgr.ConnManager = (*BasicConnMgr)(nil)
	_ connmgr.Decayer     = (*BasicConnMgr)(nil)
)

type segment struct {
	sync.Mutex
	peers map[peer.ID]*peerInfo
}

type segments [256]*segment

func (ss *segments) get(p peer.ID) *segment {
	return ss[byte(p[len(p)-1])]
}

func (ss *segments) countPeers() (count int) {
	for _, seg := range ss {
		seg.Lock()
		count += len(seg.peers)
		seg.Unlock()
	}
	return count
}

func (s *segment) tagInfoFor(p peer.ID) *peerInfo {
	pi, ok := s.peers[p]
	if ok {
		return pi
	}
	// create a temporary peer to buffer early tags before the Connected notification arrives.
	pi = &peerInfo{
		id:        p,
		firstSeen: time.Now(), // this timestamp will be updated when the first Connected notification arrives.
		temp:      true,
		tags:      make(map[string]int),
		decaying:  make(map[*decayingTag]*connmgr.DecayingValue),
		conns:     make(map[network.Conn]time.Time),
	}
	s.peers[p] = pi
	return pi
}

// NewConnManager creates a new BasicConnMgr with the provided params:
// low and hi are watermarks governing the number of connections that'll be maintained.
// When the peer count exceeds the 'high watermark', peers are pruned (and their
// connections terminated) until only 'low watermark' peers remain.
func NewConnManager(low, hi int, opts ...Option) (*BasicConnMgr, error) {
	cfg := &config{
		highWater:     hi,
		lowWater:      low,
		gracePeriod:   time.Minute,
		silencePeriod: 10 * time.Second,
	}
	for _, o := range opts {
		if err := o(cfg); err != nil {
			return nil, err
		}
	}

	if cfg.decayer == nil {
		// Set the default decayer config.
		cfg.decayer = (&DecayerCfg{}).WithDefaults()
	}

	cm := &BasicConnMgr{
		cfg:       cfg,
		protected: make(map[peer.ID]map[string]struct{}, 16),
		segments: func() (ret segments) {
			for i := range ret {
				ret[i] = &segment{
					peers: make(map[peer.ID]*peerInfo),
				}
			}
			return ret
		}(),
	}
	cm.ctx, cm.cancel = context.WithCancel(context.Background())

	if cfg.emergencyTrim {
		// When we're running low on memory, immediately trigger a trim.
		cm.unregisterMemoryWatcher = registerWatchdog(cm.memoryEmergency)
	}

	decay, _ := NewDecayer(cfg.decayer, cm)
	cm.decayer = decay

	cm.refCount.Add(1)
	go cm.background()
	return cm, nil
}
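
// exampleAttachConnManager is an illustrative usage sketch: construct a manager and
// register its Notifee with a network.Network so that Connected/Disconnected events
// feed the manager's bookkeeping. The function name, the net argument, and the
// watermark/grace-period values are placeholder choices for the example, not
// recommendations.
func exampleAttachConnManager(net network.Network) (*BasicConnMgr, error) {
	// Keep roughly between 100 and 400 connections; give new connections a
	// 30-second grace period before they become eligible for trimming.
	cm, err := NewConnManager(100, 400, WithGracePeriod(30*time.Second))
	if err != nil {
		return nil, err
	}
	// Wire connection events into the manager.
	net.Notify(cm.Notifee())
	return cm, nil
}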

// memoryEmergency is run when we run low on memory.
// Close connections until we reach the low watermark.
// We don't pay attention to the silence period or the grace period.
// We try to not kill protected connections, but if that turns out to be necessary, no connection is safe!
func (cm *BasicConnMgr) memoryEmergency() {
	connCount := int(atomic.LoadInt32(&cm.connCount))
	target := connCount - cm.cfg.lowWater
	if target < 0 {
		log.Warnw("Low on memory, but we only have a few connections", "num", connCount, "low watermark", cm.cfg.lowWater)
		return
	} else {
		log.Warnf("Low on memory. Closing %d connections.", target)
	}

	cm.trimMutex.Lock()
	defer atomic.AddUint64(&cm.trimCount, 1)
	defer cm.trimMutex.Unlock()

	// Trim connections without paying attention to the silence period.
	for _, c := range cm.getConnsToCloseEmergency(target) {
		log.Infow("low on memory. closing conn", "peer", c.RemotePeer())
		c.Close()
	}

	// finally, update the last trim time.
	cm.lastTrimMu.Lock()
	cm.lastTrim = time.Now()
	cm.lastTrimMu.Unlock()
}

func (cm *BasicConnMgr) Close() error {
	cm.cancel()
	if cm.unregisterMemoryWatcher != nil {
		cm.unregisterMemoryWatcher()
	}
	if err := cm.decayer.Close(); err != nil {
		return err
	}
	cm.refCount.Wait()
	return nil
}

func (cm *BasicConnMgr) Protect(id peer.ID, tag string) {
	cm.plk.Lock()
	defer cm.plk.Unlock()

	tags, ok := cm.protected[id]
	if !ok {
		tags = make(map[string]struct{}, 2)
		cm.protected[id] = tags
	}
	tags[tag] = struct{}{}
}

func (cm *BasicConnMgr) Unprotect(id peer.ID, tag string) (protected bool) {
	cm.plk.Lock()
	defer cm.plk.Unlock()

	tags, ok := cm.protected[id]
	if !ok {
		return false
	}
	if delete(tags, tag); len(tags) == 0 {
		delete(cm.protected, id)
		return false
	}
	return true
}

func (cm *BasicConnMgr) IsProtected(id peer.ID, tag string) (protected bool) {
	cm.plk.Lock()
	defer cm.plk.Unlock()

	tags, ok := cm.protected[id]
	if !ok {
		return false
	}

	if tag == "" {
		return true
	}

	_, protected = tags[tag]
	return protected
}
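
// exampleProtection is a short sketch of the protection API (illustrative; the
// function name and the tag names are arbitrary): a peer can be protected under
// several independent tags, and Unprotect reports whether the peer is still
// protected by any remaining tag.
func exampleProtection(cm *BasicConnMgr, p peer.ID) {
	// Keep this peer regardless of its score, under two independent tags.
	cm.Protect(p, "relay")
	cm.Protect(p, "bootstrap")

	// Unprotect reports whether the peer is still protected by another tag.
	stillProtected := cm.Unprotect(p, "relay") // true: "bootstrap" still protects p
	_ = stillProtected

	// After the last tag is removed, p competes for a slot like any other peer.
	cm.Unprotect(p, "bootstrap")
}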

// peerInfo stores metadata for a given peer.
type peerInfo struct {
	id       peer.ID
	tags     map[string]int                          // value for each tag
	decaying map[*decayingTag]*connmgr.DecayingValue // decaying tags

	value int  // cached sum of all tag values
	temp  bool // this is a temporary entry holding early tags, and awaiting connections

	conns map[network.Conn]time.Time // start time of each connection

	firstSeen time.Time // timestamp when we began tracking this peer.
}

type peerInfos []peerInfo

func (p peerInfos) SortByValue() {
	sort.Slice(p, func(i, j int) bool {
		left, right := p[i], p[j]
		// temporary peers are preferred for pruning.
		if left.temp != right.temp {
			return left.temp
		}
		// otherwise, compare by value.
		return left.value < right.value
	})
}

func (p peerInfos) SortByValueAndStreams() {
	sort.Slice(p, func(i, j int) bool {
		left, right := p[i], p[j]
		// temporary peers are preferred for pruning.
		if left.temp != right.temp {
			return left.temp
		}
		// otherwise, compare by value.
		if left.value != right.value {
			return left.value < right.value
		}
		incomingAndStreams := func(m map[network.Conn]time.Time) (incoming bool, numStreams int) {
			for c := range m {
				stat := c.Stat()
				if stat.Direction == network.DirInbound {
					incoming = true
				}
				numStreams += stat.NumStreams
			}
			return
		}
		leftIncoming, leftStreams := incomingAndStreams(left.conns)
		rightIncoming, rightStreams := incomingAndStreams(right.conns)
		// incoming connections are preferred for pruning
		if leftIncoming != rightIncoming {
			return leftIncoming
		}
		// prune connections with a higher number of streams first
		return rightStreams < leftStreams
	})
}

// TrimOpenConns closes the connections of as many peers as needed to make the peer count
// equal the low watermark. Peers are sorted in ascending order based on their total value,
// pruning those peers with the lowest scores first, as long as they are not within their
// grace period.
//
// This function blocks until a trim is completed. If a trim is underway, a new
// one won't be started, and instead it'll wait until that one is completed before
// returning.
func (cm *BasicConnMgr) TrimOpenConns(_ context.Context) {
	// TODO: error return value so we can cleanly signal we are aborting because:
	// (a) there's another trim in progress, or (b) the silence period is in effect.

	cm.doTrim()
}

func (cm *BasicConnMgr) background() {
	defer cm.refCount.Done()

	interval := cm.cfg.gracePeriod / 2
	if cm.cfg.silencePeriod != 0 {
		interval = cm.cfg.silencePeriod
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if atomic.LoadInt32(&cm.connCount) < int32(cm.cfg.highWater) {
				// Below high water, skip.
				continue
			}
		case <-cm.ctx.Done():
			return
		}
		cm.trim()
	}
}

func (cm *BasicConnMgr) doTrim() {
	// This logic is mimicking the implementation of sync.Once in the standard library.
	count := atomic.LoadUint64(&cm.trimCount)
	cm.trimMutex.Lock()
	defer cm.trimMutex.Unlock()
	if count == atomic.LoadUint64(&cm.trimCount) {
		cm.trim()
		cm.lastTrimMu.Lock()
		cm.lastTrim = time.Now()
		cm.lastTrimMu.Unlock()
		atomic.AddUint64(&cm.trimCount, 1)
	}
}

// trim closes the connections returned by getConnsToClose.
func (cm *BasicConnMgr) trim() {
	// do the actual trim.
	for _, c := range cm.getConnsToClose() {
		log.Infow("closing conn", "peer", c.RemotePeer())
		c.Close()
	}
}

func (cm *BasicConnMgr) getConnsToCloseEmergency(target int) []network.Conn {
	candidates := make(peerInfos, 0, cm.segments.countPeers())

	cm.plk.RLock()
	for _, s := range cm.segments {
		s.Lock()
		for id, inf := range s.peers {
			if _, ok := cm.protected[id]; ok {
				// skip over protected peer.
				continue
			}
			candidates = append(candidates, *inf)
		}
		s.Unlock()
	}
	cm.plk.RUnlock()

	// Sort peers according to their value.
	candidates.SortByValueAndStreams()

	selected := make([]network.Conn, 0, target+10)
	for _, inf := range candidates {
		if target <= 0 {
			break
		}
		for c := range inf.conns {
			selected = append(selected, c)
		}
		target -= len(inf.conns)
	}
	if len(selected) >= target {
		// We found enough connections that were not protected.
		return selected
	}

	// We didn't find enough unprotected connections.
	// We have no choice but to kill some protected connections.
	candidates = candidates[:0]
	cm.plk.RLock()
	for _, s := range cm.segments {
		s.Lock()
		for _, inf := range s.peers {
			candidates = append(candidates, *inf)
		}
		s.Unlock()
	}
	cm.plk.RUnlock()

	candidates.SortByValueAndStreams()
	for _, inf := range candidates {
		if target <= 0 {
			break
		}
		for c := range inf.conns {
			selected = append(selected, c)
		}
		target -= len(inf.conns)
	}
	return selected
}

// getConnsToClose runs the heuristics described in TrimOpenConns and returns the
// connections to close.
func (cm *BasicConnMgr) getConnsToClose() []network.Conn {
	if cm.cfg.lowWater == 0 || cm.cfg.highWater == 0 {
		// disabled
		return nil
	}

	if int(atomic.LoadInt32(&cm.connCount)) <= cm.cfg.lowWater {
		log.Info("open connection count below limit")
		return nil
	}

	candidates := make(peerInfos, 0, cm.segments.countPeers())
	var ncandidates int
	gracePeriodStart := time.Now().Add(-cm.cfg.gracePeriod)

	cm.plk.RLock()
	for _, s := range cm.segments {
		s.Lock()
		for id, inf := range s.peers {
			if _, ok := cm.protected[id]; ok {
				// skip over protected peer.
				continue
			}
			if inf.firstSeen.After(gracePeriodStart) {
				// skip peers in the grace period.
				continue
			}
			// note that we're copying the entry here,
			// but since inf.conns is a map, it will still point to the original object
			candidates = append(candidates, *inf)
			ncandidates += len(inf.conns)
		}
		s.Unlock()
	}
	cm.plk.RUnlock()

	if ncandidates < cm.cfg.lowWater {
		log.Info("open connection count above limit but too many are in the grace period")
		// We have too many connections but fewer than lowWater
		// connections out of the grace period.
		//
		// If we trimmed now, we'd kill potentially useful connections.
		return nil
	}

	// Sort peers according to their value.
	candidates.SortByValue()

	target := ncandidates - cm.cfg.lowWater

	// slightly overallocate because we may have more than one conn per peer
	selected := make([]network.Conn, 0, target+10)

	for _, inf := range candidates {
		if target <= 0 {
			break
		}

		// lock this to protect from concurrent modifications from connect/disconnect events
		s := cm.segments.get(inf.id)
		s.Lock()
		if len(inf.conns) == 0 && inf.temp {
			// handle temporary entries for early tags -- this entry has gone past the grace period
			// and still holds no connections, so prune it.
			delete(s.peers, inf.id)
		} else {
			for c := range inf.conns {
				selected = append(selected, c)
			}
			target -= len(inf.conns)
		}
		s.Unlock()
	}

	return selected
}

// GetTagInfo is called to fetch the tag information associated with a given
// peer; nil is returned if p refers to an unknown peer.
func (cm *BasicConnMgr) GetTagInfo(p peer.ID) *connmgr.TagInfo {
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	pi, ok := s.peers[p]
	if !ok {
		return nil
	}

	out := &connmgr.TagInfo{
		FirstSeen: pi.firstSeen,
		Value:     pi.value,
		Tags:      make(map[string]int),
		Conns:     make(map[string]time.Time),
	}

	for t, v := range pi.tags {
		out.Tags[t] = v
	}
	for t, v := range pi.decaying {
		out.Tags[t.name] = v.Value
	}
	for c, t := range pi.conns {
		out.Conns[c.RemoteMultiaddr().String()] = t
	}

	return out
}

// TagPeer is called to associate a string and integer with a given peer.
func (cm *BasicConnMgr) TagPeer(p peer.ID, tag string, val int) {
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	pi := s.tagInfoFor(p)

	// Update the total value of the peer.
	pi.value += val - pi.tags[tag]
	pi.tags[tag] = val
}

// UntagPeer is called to disassociate a string and integer from a given peer.
func (cm *BasicConnMgr) UntagPeer(p peer.ID, tag string) {
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	pi, ok := s.peers[p]
	if !ok {
		log.Info("tried to remove tag from untracked peer: ", p)
		return
	}

	// Update the total value of the peer.
	pi.value -= pi.tags[tag]
	delete(pi.tags, tag)
}

// UpsertTag is called to insert/update a peer tag.
func (cm *BasicConnMgr) UpsertTag(p peer.ID, tag string, upsert func(int) int) {
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	pi := s.tagInfoFor(p)

	oldval := pi.tags[tag]
	newval := upsert(oldval)
	pi.value += newval - oldval
	pi.tags[tag] = newval
}
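
// exampleBumpSeenCounter is an illustrative sketch of UpsertTag (the function and
// tag names are arbitrary): the callback receives the current value for the tag
// (zero if the tag is unset) and returns the new value, so the read-modify-write
// happens under the peer's segment lock.
func exampleBumpSeenCounter(cm *BasicConnMgr, p peer.ID) {
	cm.UpsertTag(p, "times-seen", func(current int) int { return current + 1 })
}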

// CMInfo holds the configuration for BasicConnMgr, as well as status data.
type CMInfo struct {
	// The low watermark, as described in NewConnManager.
	LowWater int

	// The high watermark, as described in NewConnManager.
	HighWater int

	// The timestamp when the last trim was triggered.
	LastTrim time.Time

	// The configured grace period, as described in NewConnManager.
	GracePeriod time.Duration

	// The current connection count.
	ConnCount int
}

// GetInfo returns the configuration and status data for this connection manager.
func (cm *BasicConnMgr) GetInfo() CMInfo {
	cm.lastTrimMu.RLock()
	lastTrim := cm.lastTrim
	cm.lastTrimMu.RUnlock()

	return CMInfo{
		HighWater:   cm.cfg.highWater,
		LowWater:    cm.cfg.lowWater,
		LastTrim:    lastTrim,
		GracePeriod: cm.cfg.gracePeriod,
		ConnCount:   int(atomic.LoadInt32(&cm.connCount)),
	}
}

// Notifee returns a sink through which Notifiers can inform the BasicConnMgr when
// events occur. Currently, the notifee only reacts upon connection events
// {Connected, Disconnected}.
func (cm *BasicConnMgr) Notifee() network.Notifiee {
	return (*cmNotifee)(cm)
}

type cmNotifee BasicConnMgr

func (nn *cmNotifee) cm() *BasicConnMgr {
	return (*BasicConnMgr)(nn)
}

// Connected is called by notifiers to inform that a new connection has been established.
// The notifee updates the BasicConnMgr to start tracking the connection. If the new connection
// count exceeds the high watermark, a trim may be triggered.
func (nn *cmNotifee) Connected(n network.Network, c network.Conn) {
	cm := nn.cm()

	p := c.RemotePeer()
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	id := c.RemotePeer()
	pinfo, ok := s.peers[id]
	if !ok {
		pinfo = &peerInfo{
			id:        id,
			firstSeen: time.Now(),
			tags:      make(map[string]int),
			decaying:  make(map[*decayingTag]*connmgr.DecayingValue),
			conns:     make(map[network.Conn]time.Time),
		}
		s.peers[id] = pinfo
	} else if pinfo.temp {
		// we had created a temporary entry for this peer to buffer early tags before the
		// Connected notification arrived: flip the temporary flag, and update the firstSeen
		// timestamp to the real one.
		pinfo.temp = false
		pinfo.firstSeen = time.Now()
	}

	_, ok = pinfo.conns[c]
	if ok {
		log.Error("received connected notification for conn we are already tracking: ", p)
		return
	}

	pinfo.conns[c] = time.Now()
	atomic.AddInt32(&cm.connCount, 1)
}

// Disconnected is called by notifiers to inform that an existing connection has been closed or terminated.
// The notifee updates the BasicConnMgr accordingly to stop tracking the connection, and performs housekeeping.
func (nn *cmNotifee) Disconnected(n network.Network, c network.Conn) {
	cm := nn.cm()

	p := c.RemotePeer()
	s := cm.segments.get(p)
	s.Lock()
	defer s.Unlock()

	cinf, ok := s.peers[p]
	if !ok {
		log.Error("received disconnected notification for peer we are not tracking: ", p)
		return
	}

	_, ok = cinf.conns[c]
	if !ok {
		log.Error("received disconnected notification for conn we are not tracking: ", p)
		return
	}

	delete(cinf.conns, c)
	if len(cinf.conns) == 0 {
		delete(s.peers, p)
	}
	atomic.AddInt32(&cm.connCount, -1)
}

// Listen is a no-op in this implementation.
func (nn *cmNotifee) Listen(n network.Network, addr ma.Multiaddr) {}

// ListenClose is a no-op in this implementation.
func (nn *cmNotifee) ListenClose(n network.Network, addr ma.Multiaddr) {}

// OpenedStream is a no-op in this implementation.
func (nn *cmNotifee) OpenedStream(network.Network, network.Stream) {}

// ClosedStream is a no-op in this implementation.
func (nn *cmNotifee) ClosedStream(network.Network, network.Stream) {}
@@ -0,0 +1,857 @@
package connmgr

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/libp2p/go-libp2p-core/crypto"

	"github.com/libp2p/go-libp2p-core/network"
	"github.com/libp2p/go-libp2p-core/peer"

	tu "github.com/libp2p/go-libp2p-core/test"
	ma "github.com/multiformats/go-multiaddr"

	"github.com/stretchr/testify/require"
)

type tconn struct {
	network.Conn

	peer             peer.ID
	closed           uint32 // to be used atomically. Closed if 1
	disconnectNotify func(net network.Network, conn network.Conn)
}

func (c *tconn) Close() error {
	atomic.StoreUint32(&c.closed, 1)
	if c.disconnectNotify != nil {
		c.disconnectNotify(nil, c)
	}
	return nil
}

func (c *tconn) isClosed() bool {
	return atomic.LoadUint32(&c.closed) == 1
}

func (c *tconn) RemotePeer() peer.ID {
	return c.peer
}

func (c *tconn) RemoteMultiaddr() ma.Multiaddr {
	addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234")
	if err != nil {
		panic("cannot create multiaddr")
	}
	return addr
}

func randConn(t testing.TB, discNotify func(network.Network, network.Conn)) network.Conn {
	pid := tu.RandPeerIDFatal(t)
	return &tconn{peer: pid, disconnectNotify: discNotify}
}

// Make sure multiple trim calls block.
func TestTrimBlocks(t *testing.T) {
	cm, err := NewConnManager(200, 300, WithGracePeriod(0))
	require.NoError(t, err)
	defer cm.Close()

	cm.lastTrimMu.RLock()

	doneCh := make(chan struct{}, 2)
	go func() {
		cm.TrimOpenConns(context.Background())
		doneCh <- struct{}{}
	}()
	go func() {
		cm.TrimOpenConns(context.Background())
		doneCh <- struct{}{}
	}()
	time.Sleep(time.Millisecond)
	select {
	case <-doneCh:
		cm.lastTrimMu.RUnlock()
		t.Fatal("expected trim to block")
	default:
		cm.lastTrimMu.RUnlock()
	}
	<-doneCh
	<-doneCh
}

// Make sure trim returns when closed.
func TestTrimClosed(t *testing.T) {
	cm, err := NewConnManager(200, 300, WithGracePeriod(0))
	require.NoError(t, err)
	require.NoError(t, cm.Close())
	cm.TrimOpenConns(context.Background())
}

// Make sure joining an existing trim works.
func TestTrimJoin(t *testing.T) {
	cm, err := NewConnManager(200, 300, WithGracePeriod(0))
	require.NoError(t, err)
	defer cm.Close()

	cm.lastTrimMu.RLock()
	var wg sync.WaitGroup
	wg.Add(3)
	go func() {
		defer wg.Done()
		cm.TrimOpenConns(context.Background())
	}()
	time.Sleep(time.Millisecond)
	go func() {
		defer wg.Done()
		cm.TrimOpenConns(context.Background())
	}()
	go func() {
		defer wg.Done()
		cm.TrimOpenConns(context.Background())
	}()
	time.Sleep(time.Millisecond)
	cm.lastTrimMu.RUnlock()
	wg.Wait()
}

func TestConnTrimming(t *testing.T) {
	cm, err := NewConnManager(200, 300, WithGracePeriod(0))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()

	var conns []network.Conn
	for i := 0; i < 300; i++ {
		rc := randConn(t, nil)
		conns = append(conns, rc)
		not.Connected(nil, rc)
	}

	for _, c := range conns {
		if c.(*tconn).isClosed() {
			t.Fatal("nothing should be closed yet")
		}
	}

	for i := 0; i < 100; i++ {
		cm.TagPeer(conns[i].RemotePeer(), "foo", 10)
	}

	cm.TagPeer(conns[299].RemotePeer(), "badfoo", -5)

	cm.TrimOpenConns(context.Background())

	for i := 0; i < 100; i++ {
		c := conns[i]
		if c.(*tconn).isClosed() {
			t.Fatal("these shouldn't be closed")
		}
	}

	if !conns[299].(*tconn).isClosed() {
		t.Fatal("conn with bad tag should have gotten closed")
	}
}

func TestConnsToClose(t *testing.T) {
	addConns := func(cm *BasicConnMgr, n int) {
		not := cm.Notifee()
		for i := 0; i < n; i++ {
			conn := randConn(t, nil)
			not.Connected(nil, conn)
		}
	}

	t.Run("below hi limit", func(t *testing.T) {
		cm, err := NewConnManager(0, 10, WithGracePeriod(0))
		require.NoError(t, err)
		defer cm.Close()
		addConns(cm, 5)
		require.Empty(t, cm.getConnsToClose())
	})

	t.Run("below low limit", func(t *testing.T) {
		cm, err := NewConnManager(10, 0, WithGracePeriod(0))
		require.NoError(t, err)
		defer cm.Close()
		addConns(cm, 5)
		require.Empty(t, cm.getConnsToClose())
	})

	t.Run("below low and hi limit", func(t *testing.T) {
		cm, err := NewConnManager(1, 1, WithGracePeriod(0))
		require.NoError(t, err)
		defer cm.Close()
		addConns(cm, 1)
		require.Empty(t, cm.getConnsToClose())
	})

	t.Run("within silence period", func(t *testing.T) {
		cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute))
		require.NoError(t, err)
		defer cm.Close()
		addConns(cm, 1)
		require.Empty(t, cm.getConnsToClose())
	})
}

func TestGetTagInfo(t *testing.T) {
	start := time.Now()
	cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute))
	require.NoError(t, err)
	defer cm.Close()

	not := cm.Notifee()
	conn := randConn(t, nil)
	not.Connected(nil, conn)
	end := time.Now()

	other := tu.RandPeerIDFatal(t)
	tag := cm.GetTagInfo(other)
	if tag != nil {
		t.Fatal("expected no tag")
	}

	tag = cm.GetTagInfo(conn.RemotePeer())
	if tag == nil {
		t.Fatal("expected tag")
	}
	if tag.FirstSeen.Before(start) || tag.FirstSeen.After(end) {
		t.Fatal("expected first seen time")
	}
	if tag.Value != 0 {
		t.Fatal("expected zero value")
	}
	if len(tag.Tags) != 0 {
		t.Fatal("expected no tags")
	}
	if len(tag.Conns) != 1 {
		t.Fatal("expected one connection")
	}
	for s, tm := range tag.Conns {
		if s != conn.RemoteMultiaddr().String() {
			t.Fatal("unexpected multiaddr")
		}
		if tm.Before(start) || tm.After(end) {
			t.Fatal("unexpected connection time")
		}
	}

	cm.TagPeer(conn.RemotePeer(), "tag", 5)
	tag = cm.GetTagInfo(conn.RemotePeer())
	if tag == nil {
		t.Fatal("expected tag")
	}
	if tag.FirstSeen.Before(start) || tag.FirstSeen.After(end) {
		t.Fatal("expected first seen time")
	}
	if tag.Value != 5 {
		t.Fatal("expected value of 5")
	}
	if len(tag.Tags) != 1 {
		t.Fatal("expected one tag")
	}
	for tString, v := range tag.Tags {
		if tString != "tag" || v != 5 {
			t.Fatal("expected tag value")
		}
	}
	if len(tag.Conns) != 1 {
		t.Fatal("expected one connection")
	}
	for s, tm := range tag.Conns {
		if s != conn.RemoteMultiaddr().String() {
			t.Fatal("unexpected multiaddr")
		}
		if tm.Before(start) || tm.After(end) {
			t.Fatal("unexpected connection time")
		}
	}
}

func TestTagPeerNonExistant(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute))
	require.NoError(t, err)
	defer cm.Close()

	id := tu.RandPeerIDFatal(t)
	cm.TagPeer(id, "test", 1)

	if !cm.segments.get(id).peers[id].temp {
		t.Fatal("expected 1 temporary entry")
	}
}

func TestUntagPeer(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()

	conn := randConn(t, nil)
	not.Connected(nil, conn)
	rp := conn.RemotePeer()
	cm.TagPeer(rp, "tag", 5)
	cm.TagPeer(rp, "tag two", 5)

	id := tu.RandPeerIDFatal(t)
	cm.UntagPeer(id, "test")
	if len(cm.segments.get(rp).peers[rp].tags) != 2 {
		t.Fatal("expected tags to be unaffected")
	}

	cm.UntagPeer(conn.RemotePeer(), "test")
	if len(cm.segments.get(rp).peers[rp].tags) != 2 {
		t.Fatal("expected tags to be unaffected")
	}

	cm.UntagPeer(conn.RemotePeer(), "tag")
	if len(cm.segments.get(rp).peers[rp].tags) != 1 {
		t.Fatal("expected tag to be removed")
	}
	if cm.segments.get(rp).peers[rp].value != 5 {
		t.Fatal("expected aggregate tag value to be 5")
	}
}

func TestGetInfo(t *testing.T) {
	start := time.Now()
	const gp = 10 * time.Minute
	cm, err := NewConnManager(1, 5, WithGracePeriod(gp))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()
	conn := randConn(t, nil)
	not.Connected(nil, conn)
	cm.TrimOpenConns(context.Background())
	end := time.Now()

	info := cm.GetInfo()
	if info.HighWater != 5 {
		t.Fatal("expected highwater to be 5")
	}
	if info.LowWater != 1 {
		t.Fatal("expected lowwater to be 1")
	}
	if info.LastTrim.Before(start) || info.LastTrim.After(end) {
		t.Fatal("unexpected last trim time")
	}
	if info.GracePeriod != gp {
		t.Fatal("unexpected grace period")
	}
	if info.ConnCount != 1 {
		t.Fatal("unexpected number of connections")
	}
}

func TestDoubleConnection(t *testing.T) {
	const gp = 10 * time.Minute
	cm, err := NewConnManager(1, 5, WithGracePeriod(gp))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()
	conn := randConn(t, nil)
	not.Connected(nil, conn)
	cm.TagPeer(conn.RemotePeer(), "foo", 10)
	not.Connected(nil, conn)
	if cm.connCount != 1 {
		t.Fatal("unexpected number of connections")
	}
	if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 {
		t.Fatal("unexpected peer value")
	}
}

func TestDisconnected(t *testing.T) {
	const gp = 10 * time.Minute
	cm, err := NewConnManager(1, 5, WithGracePeriod(gp))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()
	conn := randConn(t, nil)
	not.Connected(nil, conn)
	cm.TagPeer(conn.RemotePeer(), "foo", 10)

	not.Disconnected(nil, randConn(t, nil))
	if cm.connCount != 1 {
		t.Fatal("unexpected number of connections")
	}
	if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 {
		t.Fatal("unexpected peer value")
	}

	not.Disconnected(nil, &tconn{peer: conn.RemotePeer()})
	if cm.connCount != 1 {
		t.Fatal("unexpected number of connections")
	}
	if cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()].value != 10 {
		t.Fatal("unexpected peer value")
	}

	not.Disconnected(nil, conn)
	if cm.connCount != 0 {
		t.Fatal("unexpected number of connections")
	}
	if cm.segments.countPeers() != 0 {
		t.Fatal("unexpected number of peers")
	}
}

func TestGracePeriod(t *testing.T) {
	const gp = 100 * time.Millisecond
	cm, err := NewConnManager(10, 20, WithGracePeriod(gp), WithSilencePeriod(time.Hour))
	require.NoError(t, err)
	defer cm.Close()

	not := cm.Notifee()

	var conns []network.Conn

	// Add a connection and wait the grace period.
	{
		rc := randConn(t, not.Disconnected)
		conns = append(conns, rc)
		not.Connected(nil, rc)

		time.Sleep(2 * gp)

		if rc.(*tconn).isClosed() {
			t.Fatal("expected conn to remain open")
		}
	}

	// quickly add 30 connections (sending us above the high watermark)
	for i := 0; i < 30; i++ {
		rc := randConn(t, not.Disconnected)
		conns = append(conns, rc)
		not.Connected(nil, rc)
	}

	cm.TrimOpenConns(context.Background())

	for _, c := range conns {
		if c.(*tconn).isClosed() {
			t.Fatal("expected no conns to be closed")
		}
	}

	time.Sleep(200 * time.Millisecond)

	cm.TrimOpenConns(context.Background())

	closed := 0
	for _, c := range conns {
		if c.(*tconn).isClosed() {
			closed++
		}
	}

	if closed != 21 {
		t.Fatal("expected to have closed 21 connections")
	}
}

// see https://github.com/libp2p/go-libp2p-connmgr/issues/23
func TestQuickBurstRespectsSilencePeriod(t *testing.T) {
	cm, err := NewConnManager(10, 20, WithGracePeriod(0))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()

	var conns []network.Conn

	// quickly produce 30 connections (sending us above the high watermark)
	for i := 0; i < 30; i++ {
		rc := randConn(t, not.Disconnected)
		conns = append(conns, rc)
		not.Connected(nil, rc)
	}

	// wait for a few seconds
	time.Sleep(time.Second * 3)

	// only the first trim is allowed in; make sure we close at most 20 connections, not all of them.
	var closed int
	for _, c := range conns {
		if c.(*tconn).isClosed() {
			closed++
		}
	}
	if closed > 20 {
		t.Fatalf("should have closed at most 20 connections, closed: %d", closed)
	}
	if total := closed + int(cm.connCount); total != 30 {
		t.Fatalf("expected closed connections + open conn count to equal 30, value: %d", total)
	}
}

func TestPeerProtectionSingleTag(t *testing.T) {
	cm, err := NewConnManager(19, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()

	var conns []network.Conn
	addConn := func(value int) {
		rc := randConn(t, not.Disconnected)
		conns = append(conns, rc)
		not.Connected(nil, rc)
		cm.TagPeer(rc.RemotePeer(), "test", value)
	}

	// produce 20 connections with unique peers.
	for i := 0; i < 20; i++ {
		addConn(20)
	}

	// protect the first 5 peers.
	var protected []network.Conn
	for _, c := range conns[0:5] {
		cm.Protect(c.RemotePeer(), "global")
		protected = append(protected, c)
		// tag them negatively to make them preferred for pruning.
		cm.TagPeer(c.RemotePeer(), "test", -100)
	}

	// add 1 more conn, this shouldn't send us over the limit as protected conns don't count
	addConn(20)

	time.Sleep(100 * time.Millisecond)
	cm.TrimOpenConns(context.Background())

	for _, c := range conns {
		if c.(*tconn).isClosed() {
			t.Error("connection was closed by connection manager")
		}
	}

	// add 5 more connections, sending the connection manager overboard.
	for i := 0; i < 5; i++ {
		addConn(20)
	}

	cm.TrimOpenConns(context.Background())

	for _, c := range protected {
		if c.(*tconn).isClosed() {
			t.Error("protected connection was closed by connection manager")
		}
	}

	closed := 0
	for _, c := range conns {
		if c.(*tconn).isClosed() {
			closed++
		}
	}
	if closed != 2 {
		t.Errorf("expected 2 connections to be closed, found %d", closed)
	}

	// unprotect the first peer.
	cm.Unprotect(protected[0].RemotePeer(), "global")

	// add 2 more connections, sending the connection manager overboard again.
	for i := 0; i < 2; i++ {
		addConn(20)
	}

	cm.TrimOpenConns(context.Background())

	if !protected[0].(*tconn).isClosed() {
		t.Error("unprotected connection was kept open by connection manager")
	}
	for _, c := range protected[1:] {
		if c.(*tconn).isClosed() {
			t.Error("protected connection was closed by connection manager")
		}
	}
}

func TestPeerProtectionMultipleTags(t *testing.T) {
	cm, err := NewConnManager(19, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()

	// produce 20 connections with unique peers.
	var conns []network.Conn
	for i := 0; i < 20; i++ {
		rc := randConn(t, not.Disconnected)
		conns = append(conns, rc)
		not.Connected(nil, rc)
		cm.TagPeer(rc.RemotePeer(), "test", 20)
	}

	// protect the first 5 peers under two tags.
	var protected []network.Conn
	for _, c := range conns[0:5] {
		cm.Protect(c.RemotePeer(), "tag1")
		cm.Protect(c.RemotePeer(), "tag2")
		protected = append(protected, c)
		// tag them negatively to make them preferred for pruning.
		cm.TagPeer(c.RemotePeer(), "test", -100)
	}

	// add one more connection, sending the connection manager overboard.
	not.Connected(nil, randConn(t, not.Disconnected))

	cm.TrimOpenConns(context.Background())

	for _, c := range protected {
		if c.(*tconn).isClosed() {
			t.Error("protected connection was closed by connection manager")
		}
	}

	// remove the protection from one tag.
	for _, c := range protected {
		if !cm.Unprotect(c.RemotePeer(), "tag1") {
			t.Error("peer should still be protected")
		}
	}

	// add 2 more connections, sending the connection manager overboard again.
	for i := 0; i < 2; i++ {
		rc := randConn(t, not.Disconnected)
		not.Connected(nil, rc)
		cm.TagPeer(rc.RemotePeer(), "test", 20)
	}

	cm.TrimOpenConns(context.Background())

	// connections should still remain open, as they were protected.
	for _, c := range protected[0:] {
		if c.(*tconn).isClosed() {
			t.Error("protected connection was closed by connection manager")
		}
	}

	// unprotect the first peer entirely.
	cm.Unprotect(protected[0].RemotePeer(), "tag2")

	// add 2 more connections, sending the connection manager overboard again.
	for i := 0; i < 2; i++ {
		rc := randConn(t, not.Disconnected)
		not.Connected(nil, rc)
		cm.TagPeer(rc.RemotePeer(), "test", 20)
	}

	cm.TrimOpenConns(context.Background())

	if !protected[0].(*tconn).isClosed() {
		t.Error("unprotected connection was kept open by connection manager")
	}
	for _, c := range protected[1:] {
		if c.(*tconn).isClosed() {
			t.Error("protected connection was closed by connection manager")
		}
	}
}

func TestPeerProtectionIdempotent(t *testing.T) {
	cm, err := NewConnManager(10, 20, WithGracePeriod(0), WithSilencePeriod(time.Hour))
	require.NoError(t, err)
	defer cm.Close()

	id, _ := tu.RandPeerID()
	cm.Protect(id, "global")
	cm.Protect(id, "global")
	cm.Protect(id, "global")
	cm.Protect(id, "global")

	if len(cm.protected[id]) > 1 {
		t.Error("expected peer to be protected only once")
	}

	if !cm.Unprotect(id, "unused") {
		t.Error("expected peer to continue to be protected")
	}

	if !cm.Unprotect(id, "unused2") {
		t.Error("expected peer to continue to be protected")
	}

	if cm.Unprotect(id, "global") {
		t.Error("expected peer to be unprotected")
	}

	if len(cm.protected) > 0 {
		t.Error("expected no protections")
	}
}

func TestUpsertTag(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(10*time.Minute))
	require.NoError(t, err)
	defer cm.Close()
	not := cm.Notifee()
	conn := randConn(t, nil)
	rp := conn.RemotePeer()

	// this is an early tag, before the Connected notification arrived.
	cm.UpsertTag(rp, "tag", func(v int) int { return v + 1 })
	if len(cm.segments.get(rp).peers[rp].tags) != 1 {
		t.Fatal("expected a tag")
	}
	if cm.segments.get(rp).peers[rp].value != 1 {
		t.Fatal("expected a tag value of 1")
	}

	// now let's notify the connection.
	not.Connected(nil, conn)

	cm.UpsertTag(rp, "tag", func(v int) int { return v + 1 })
	if len(cm.segments.get(rp).peers[rp].tags) != 1 {
		t.Fatal("expected a tag")
	}
	if cm.segments.get(rp).peers[rp].value != 2 {
		t.Fatal("expected a tag value of 2")
	}

	cm.UpsertTag(rp, "tag", func(v int) int { return v - 1 })
	if len(cm.segments.get(rp).peers[rp].tags) != 1 {
		t.Fatal("expected a tag")
	}
	if cm.segments.get(rp).peers[rp].value != 1 {
		t.Fatal("expected a tag value of 1")
	}
}

func TestTemporaryEntriesClearedFirst(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(0))
	require.NoError(t, err)

	id := tu.RandPeerIDFatal(t)
	cm.TagPeer(id, "test", 20)

	if cm.GetTagInfo(id).Value != 20 {
		t.Fatal("expected an early tag with value 20")
	}

	not := cm.Notifee()
	conn1, conn2 := randConn(t, nil), randConn(t, nil)
	not.Connected(nil, conn1)
	not.Connected(nil, conn2)

	cm.TrimOpenConns(context.Background())
	if cm.GetTagInfo(id) != nil {
		t.Fatal("expected no temporary tags after trimming")
	}
}

func TestTemporaryEntryConvertedOnConnection(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(0))
	require.NoError(t, err)
	defer cm.Close()

	conn := randConn(t, nil)
	cm.TagPeer(conn.RemotePeer(), "test", 20)

	ti := cm.segments.get(conn.RemotePeer()).peers[conn.RemotePeer()]

	if ti.value != 20 || !ti.temp {
		t.Fatal("expected a temporary tag with value 20")
	}

	not := cm.Notifee()
	not.Connected(nil, conn)

	if ti.value != 20 || ti.temp {
		t.Fatal("expected a non-temporary tag with value 20")
	}
}

// see https://github.com/libp2p/go-libp2p-connmgr/issues/82
func TestConcurrentCleanupAndTagging(t *testing.T) {
	cm, err := NewConnManager(1, 1, WithGracePeriod(0), WithSilencePeriod(time.Millisecond))
	require.NoError(t, err)
	defer cm.Close()

	for i := 0; i < 1000; i++ {
		conn := randConn(t, nil)
		cm.TagPeer(conn.RemotePeer(), "test", 20)
	}
}

type mockConn struct {
	stats network.ConnStats
}

func (m mockConn) Close() error                                          { panic("implement me") }
func (m mockConn) LocalPeer() peer.ID                                    { panic("implement me") }
func (m mockConn) LocalPrivateKey() crypto.PrivKey                       { panic("implement me") }
func (m mockConn) RemotePeer() peer.ID                                   { panic("implement me") }
func (m mockConn) RemotePublicKey() crypto.PubKey                        { panic("implement me") }
func (m mockConn) LocalMultiaddr() ma.Multiaddr                          { panic("implement me") }
func (m mockConn) RemoteMultiaddr() ma.Multiaddr                         { panic("implement me") }
func (m mockConn) Stat() network.ConnStats                               { return m.stats }
func (m mockConn) ID() string                                            { panic("implement me") }
func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic("implement me") }
func (m mockConn) GetStreams() []network.Stream                          { panic("implement me") }

func TestPeerInfoSorting(t *testing.T) {
	t.Run("starts with temporary connections", func(t *testing.T) {
		p1 := peerInfo{id: peer.ID("peer1")}
		p2 := peerInfo{id: peer.ID("peer2"), temp: true}
		pis := peerInfos{p1, p2}
		pis.SortByValue()
		require.Equal(t, pis, peerInfos{p2, p1})
	})

	t.Run("starts with low-value connections", func(t *testing.T) {
		p1 := peerInfo{id: peer.ID("peer1"), value: 40}
		p2 := peerInfo{id: peer.ID("peer2"), value: 20}
		pis := peerInfos{p1, p2}
		pis.SortByValue()
		require.Equal(t, pis, peerInfos{p2, p1})
	})

	t.Run("in a memory emergency, starts with incoming connections", func(t *testing.T) {
		incoming := network.ConnStats{}
		incoming.Direction = network.DirInbound
		outgoing := network.ConnStats{}
		outgoing.Direction = network.DirOutbound
		p1 := peerInfo{
			id: peer.ID("peer1"),
			conns: map[network.Conn]time.Time{
				&mockConn{stats: outgoing}: time.Now(),
			},
		}
		p2 := peerInfo{
			id: peer.ID("peer2"),
			conns: map[network.Conn]time.Time{
				&mockConn{stats: outgoing}: time.Now(),
				&mockConn{stats: incoming}: time.Now(),
			},
		}
		pis := peerInfos{p1, p2}
		pis.SortByValueAndStreams()
		require.Equal(t, pis, peerInfos{p2, p1})
	})

	t.Run("in a memory emergency, starts with connections that have many streams", func(t *testing.T) {
		p1 := peerInfo{
			id: peer.ID("peer1"),
			conns: map[network.Conn]time.Time{
				&mockConn{stats: network.ConnStats{NumStreams: 100}}: time.Now(),
			},
		}
		p2 := peerInfo{
			id: peer.ID("peer2"),
			conns: map[network.Conn]time.Time{
				&mockConn{stats: network.ConnStats{NumStreams: 80}}: time.Now(),
				&mockConn{stats: network.ConnStats{NumStreams: 40}}: time.Now(),
			},
		}
		pis := peerInfos{p1, p2}
		pis.SortByValueAndStreams()
		require.Equal(t, pis, peerInfos{p2, p1})
	})
}
@@ -0,0 +1,355 @@
package connmgr

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/libp2p/go-libp2p-core/connmgr"
	"github.com/libp2p/go-libp2p-core/peer"

	"github.com/benbjohnson/clock"
)

// DefaultResolution is the default resolution of the decay tracker.
var DefaultResolution = 1 * time.Minute

// bumpCmd represents a bump command.
type bumpCmd struct {
	peer  peer.ID
	tag   *decayingTag
	delta int
}

// removeCmd represents a tag removal command.
type removeCmd struct {
	peer peer.ID
	tag  *decayingTag
}

// decayer tracks and manages all decaying tags and their values.
type decayer struct {
	cfg   *DecayerCfg
	mgr   *BasicConnMgr
	clock clock.Clock // for testing.

	tagsMu    sync.Mutex
	knownTags map[string]*decayingTag

	// lastTick stores the last time the decayer ticked. Guarded by atomic.
	lastTick atomic.Value

	// bumpTagCh queues bump commands to be processed by the loop.
	bumpTagCh   chan bumpCmd
	removeTagCh chan removeCmd
	closeTagCh  chan *decayingTag

	// closure thingies.
	closeCh chan struct{}
	doneCh  chan struct{}
	err     error
}

var _ connmgr.Decayer = (*decayer)(nil)

// DecayerCfg is the configuration object for the Decayer.
type DecayerCfg struct {
	Resolution time.Duration
	Clock      clock.Clock
}

// WithDefaults writes the default values on this DecayerCfg instance,
// and returns itself for chainability.
//
//	cfg := (&DecayerCfg{}).WithDefaults()
//	cfg.Resolution = 30 * time.Second
//	t := NewDecayer(cfg, cm)
func (cfg *DecayerCfg) WithDefaults() *DecayerCfg {
	cfg.Resolution = DefaultResolution
	return cfg
}

// NewDecayer creates a new decaying tag registry.
func NewDecayer(cfg *DecayerCfg, mgr *BasicConnMgr) (*decayer, error) {
	// use real time if the Clock in the config is nil.
	if cfg.Clock == nil {
		cfg.Clock = clock.New()
	}

	d := &decayer{
		cfg:         cfg,
		mgr:         mgr,
		clock:       cfg.Clock,
		knownTags:   make(map[string]*decayingTag),
		bumpTagCh:   make(chan bumpCmd, 128),
		removeTagCh: make(chan removeCmd, 128),
		closeTagCh:  make(chan *decayingTag, 128),
		closeCh:     make(chan struct{}),
		doneCh:      make(chan struct{}),
	}

	d.lastTick.Store(d.clock.Now())

	// kick things off.
	go d.process()

	return d, nil
}

func (d *decayer) RegisterDecayingTag(name string, interval time.Duration, decayFn connmgr.DecayFn, bumpFn connmgr.BumpFn) (connmgr.DecayingTag, error) {
	d.tagsMu.Lock()
	defer d.tagsMu.Unlock()

	if _, ok := d.knownTags[name]; ok {
		return nil, fmt.Errorf("decaying tag with name %s already exists", name)
	}

	if interval < d.cfg.Resolution {
		log.Warnf("decay interval for %s (%s) was lower than tracker's resolution (%s); overridden to resolution",
			name, interval, d.cfg.Resolution)
		interval = d.cfg.Resolution
	}

	if interval%d.cfg.Resolution != 0 {
		log.Warnf("decay interval for tag %s (%s) is not a multiple of tracker's resolution (%s); "+
			"some precision may be lost", name, interval, d.cfg.Resolution)
	}

	lastTick := d.lastTick.Load().(time.Time)
	tag := &decayingTag{
		trkr:     d,
		name:     name,
		interval: interval,
		nextTick: lastTick.Add(interval),
		decayFn:  decayFn,
		bumpFn:   bumpFn,
	}

	d.knownTags[name] = tag
	return tag, nil
}
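
// exampleDecayingTag is an illustrative sketch of registering and bumping a decaying
// tag. The function name, the tag name, the interval, and the parameter values are
// arbitrary choices for the example; the decay and bump helpers
// (DecayExpireWhenInactive, BumpSumBounded) are the same ones exercised by the decay
// tests in this changeset.
func exampleDecayingTag(d *decayer, p peer.ID) error {
	tag, err := d.RegisterDecayingTag("usefulness", time.Minute,
		connmgr.DecayExpireWhenInactive(10*time.Minute), // drop the value if not bumped for 10 minutes
		connmgr.BumpSumBounded(0, 100))                  // accumulate bumps, clamped to [0, 100]
	if err != nil {
		return err
	}
	// Credit the peer; the decayer's loop applies the decay on its resolution ticks.
	return tag.Bump(p, 10)
}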
|||
|
|||
// Close closes the Decayer. It is idempotent.
|
|||
func (d *decayer) Close() error { |
|||
select { |
|||
case <-d.doneCh: |
|||
return d.err |
|||
default: |
|||
} |
|||
|
|||
close(d.closeCh) |
|||
<-d.doneCh |
|||
return d.err |
|||
} |
|||
|
|||
// process is the heart of the tracker. It performs the following duties:
|
|||
//
|
|||
// 1. Manages decay.
|
|||
// 2. Applies score bumps.
|
|||
// 3. Yields when closed.
|
|||
func (d *decayer) process() { |
|||
defer close(d.doneCh) |
|||
|
|||
ticker := d.clock.Ticker(d.cfg.Resolution) |
|||
defer ticker.Stop() |
|||
|
|||
var ( |
|||
bmp bumpCmd |
|||
now time.Time |
|||
visit = make(map[*decayingTag]struct{}) |
|||
) |
|||
|
|||
for { |
|||
select { |
|||
case now = <-ticker.C: |
|||
d.lastTick.Store(now) |
|||
|
|||
d.tagsMu.Lock() |
|||
for _, tag := range d.knownTags { |
|||
if tag.nextTick.After(now) { |
|||
// skip the tag.
|
|||
continue |
|||
} |
|||
// Mark the tag to be updated in this round.
|
|||
visit[tag] = struct{}{} |
|||
} |
|||
d.tagsMu.Unlock() |
|||
|
|||
// Visit each peer, and decay tags that need to be decayed.
|
|||
for _, s := range d.mgr.segments { |
|||
s.Lock() |
|||
|
|||
// Entered a segment that contains peers. Process each peer.
|
|||
for _, p := range s.peers { |
|||
for tag, v := range p.decaying { |
|||
if _, ok := visit[tag]; !ok { |
|||
// skip this tag.
|
|||
continue |
|||
} |
|||
|
|||
// ~ this value needs to be visited. ~
|
|||
var delta int |
|||
if after, rm := tag.decayFn(*v); rm { |
|||
// delete the value and move on to the next tag.
|
|||
delta -= v.Value |
|||
delete(p.decaying, tag) |
|||
} else { |
|||
// accumulate the delta, and apply the changes.
|
|||
delta += after - v.Value |
|||
v.Value, v.LastVisit = after, now |
|||
} |
|||
p.value += delta |
|||
} |
|||
} |
|||
|
|||
s.Unlock() |
|||
} |
|||
|
|||
// Reset each tag's next visit round, and clear the visited set.
|
|||
for tag := range visit { |
|||
tag.nextTick = tag.nextTick.Add(tag.interval) |
|||
delete(visit, tag) |
|||
} |
|||
|
|||
case bmp = <-d.bumpTagCh: |
|||
var ( |
|||
now = d.clock.Now() |
|||
peer, tag = bmp.peer, bmp.tag |
|||
) |
|||
|
|||
s := d.mgr.segments.get(peer) |
|||
s.Lock() |
|||
|
|||
p := s.tagInfoFor(peer) |
|||
v, ok := p.decaying[tag] |
|||
if !ok { |
|||
v = &connmgr.DecayingValue{ |
|||
Tag: tag, |
|||
Peer: peer, |
|||
LastVisit: now, |
|||
Added: now, |
|||
Value: 0, |
|||
} |
|||
p.decaying[tag] = v |
|||
} |
|||
|
|||
prev := v.Value |
|||
v.Value, v.LastVisit = v.Tag.(*decayingTag).bumpFn(*v, bmp.delta), now |
|||
p.value += v.Value - prev |
|||
|
|||
s.Unlock() |
|||
|
|||
case rm := <-d.removeTagCh: |
|||
s := d.mgr.segments.get(rm.peer) |
|||
s.Lock() |
|||
|
|||
p := s.tagInfoFor(rm.peer) |
|||
v, ok := p.decaying[rm.tag] |
|||
if !ok { |
|||
s.Unlock() |
|||
continue |
|||
} |
|||
p.value -= v.Value |
|||
delete(p.decaying, rm.tag) |
|||
s.Unlock() |
|||
|
|||
case t := <-d.closeTagCh: |
|||
// Stop tracking the tag.
|
|||
d.tagsMu.Lock() |
|||
delete(d.knownTags, t.name) |
|||
d.tagsMu.Unlock() |
|||
|
|||
// Remove the tag from all peers that had it in the connmgr.
|
|||
for _, s := range d.mgr.segments { |
|||
// visit all segments, and attempt to remove the tag from all the peers it stores.
|
|||
s.Lock() |
|||
for _, p := range s.peers { |
|||
if dt, ok := p.decaying[t]; ok { |
|||
// decrease the value of the tagInfo, and delete the tag.
|
|||
p.value -= dt.Value |
|||
delete(p.decaying, t) |
|||
} |
|||
} |
|||
s.Unlock() |
|||
} |
|||
|
|||
case <-d.closeCh: |
|||
return |
|||
} |
|||
} |
|||
} |

// decayingTag represents a decaying tag, with an associated decay interval, a
// decay function, and a bump function.
type decayingTag struct {
	trkr     *decayer
	name     string
	interval time.Duration
	nextTick time.Time
	decayFn  connmgr.DecayFn
	bumpFn   connmgr.BumpFn

	// closed marks this tag as closed, so that if it's bumped after being
	// closed, we can return an error. 0 = false; 1 = true; guarded by atomic.
	closed int32
}

var _ connmgr.DecayingTag = (*decayingTag)(nil)

func (t *decayingTag) Name() string {
	return t.name
}

func (t *decayingTag) Interval() time.Duration {
	return t.interval
}

// Bump bumps a tag for this peer.
func (t *decayingTag) Bump(p peer.ID, delta int) error {
	if atomic.LoadInt32(&t.closed) == 1 {
		return fmt.Errorf("decaying tag %s had been closed; no further bumps are accepted", t.name)
	}

	bmp := bumpCmd{peer: p, tag: t, delta: delta}

	select {
	case t.trkr.bumpTagCh <- bmp:
		return nil
	default:
		return fmt.Errorf(
			"unable to bump decaying tag for peer %s, tag %s, delta %d; queue full (len=%d)",
			p.Pretty(), t.name, delta, len(t.trkr.bumpTagCh))
	}
}
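
// Usage sketch (illustrative, not part of this diff): Bump only enqueues a
// bumpCmd for the decayer goroutine, so the returned error is a back-pressure or
// lifecycle signal rather than a validation result. peerID below is assumed to
// be a peer.ID in the caller's scope:
//
//	if err := tag.Bump(peerID, 10); err != nil {
//		// queue full or tag already closed; drop the bump or retry later.
//		log.Debugf("bump not applied: %s", err)
//	}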

func (t *decayingTag) Remove(p peer.ID) error {
	if atomic.LoadInt32(&t.closed) == 1 {
		return fmt.Errorf("decaying tag %s had been closed; no further removals are accepted", t.name)
	}

	rm := removeCmd{peer: p, tag: t}

	select {
	case t.trkr.removeTagCh <- rm:
		return nil
	default:
		return fmt.Errorf(
			"unable to remove decaying tag for peer %s, tag %s; queue full (len=%d)",
			p.Pretty(), t.name, len(t.trkr.removeTagCh))
	}
}

func (t *decayingTag) Close() error {
	if !atomic.CompareAndSwapInt32(&t.closed, 0, 1) {
		log.Warnf("duplicate decaying tag closure: %s; skipping", t.name)
		return nil
	}

	select {
	case t.trkr.closeTagCh <- t:
		return nil
	default:
		return fmt.Errorf("unable to close decaying tag %s; queue full (len=%d)", t.name, len(t.trkr.closeTagCh))
	}
}
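
// Lifecycle sketch (illustrative, not part of this diff), assuming a Decayer
// obtained via connmgr.SupportsDecay as in the tests below, and a peerID in
// scope:
//
//	tag, err := decay.RegisterDecayingTag("useful", time.Minute,
//		connmgr.DecayFixed(1), connmgr.BumpSumBounded(0, 100))
//	if err != nil {
//		return err
//	}
//	_ = tag.Bump(peerID, 10) // raise the peer's score for this tag.
//	_ = tag.Remove(peerID)   // drop the tag from a single peer.
//	_ = tag.Close()          // stop tracking the tag everywhere; later Bumps error.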
@ -0,0 +1,421 @@
package connmgr

import (
	"testing"
	"time"

	"github.com/libp2p/go-libp2p-core/connmgr"
	"github.com/libp2p/go-libp2p-core/peer"
	tu "github.com/libp2p/go-libp2p-core/test"
	"github.com/stretchr/testify/require"

	"github.com/benbjohnson/clock"
)

const TestResolution = 50 * time.Millisecond

func TestDecayExpire(t *testing.T) {
	var (
		id                    = tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
	)

	tag, err := decay.RegisterDecayingTag("pop", 250*time.Millisecond, connmgr.DecayExpireWhenInactive(1*time.Second), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	err = tag.Bump(id, 10)
	if err != nil {
		t.Fatal(err)
	}

	// give time for the bump command to process.
	<-time.After(100 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 10 {
		t.Fatalf("wrong value; expected = %d; got = %d", 10, v)
	}

	mockClock.Add(250 * time.Millisecond)
	mockClock.Add(250 * time.Millisecond)
	mockClock.Add(250 * time.Millisecond)
	mockClock.Add(250 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 0 {
		t.Fatalf("wrong value; expected = %d; got = %d", 0, v)
	}
}
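
// For reference, an expire-style decay function with the behaviour this test
// exercises (remove the whole value once no bump has been seen for a while)
// could be written as below. This is a hedged sketch, not the library's exact
// DecayExpireWhenInactive implementation:
//
//	expireAfter := func(d time.Duration) connmgr.DecayFn {
//		return func(v connmgr.DecayingValue) (after int, rm bool) {
//			return 0, time.Since(v.LastVisit) >= d
//		}
//	}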

func TestMultipleBumps(t *testing.T) {
	var (
		id            = tu.RandPeerIDFatal(t)
		mgr, decay, _ = testDecayTracker(t)
	)

	tag, err := decay.RegisterDecayingTag("pop", 250*time.Millisecond, connmgr.DecayExpireWhenInactive(1*time.Second), connmgr.BumpSumBounded(10, 20))
	if err != nil {
		t.Fatal(err)
	}

	err = tag.Bump(id, 5)
	if err != nil {
		t.Fatal(err)
	}

	<-time.After(100 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 10 {
		t.Fatalf("wrong value; expected = %d; got = %d", 10, v)
	}

	err = tag.Bump(id, 100)
	if err != nil {
		t.Fatal(err)
	}

	<-time.After(100 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 20 {
		t.Fatalf("wrong value; expected = %d; got = %d", 20, v)
	}
}
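
// For reference (assumed semantics, consistent with the assertions above):
// BumpSumBounded(min, max) adds the delta and clamps the result to [min, max].
// A sketch with the same shape, using a hypothetical name:
//
//	sumBounded := func(min, max int) connmgr.BumpFn {
//		return func(v connmgr.DecayingValue, delta int) (after int) {
//			if after = v.Value + delta; after > max {
//				after = max
//			} else if after < min {
//				after = min
//			}
//			return after
//		}
//	}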

func TestMultipleTagsNoDecay(t *testing.T) {
	var (
		id            = tu.RandPeerIDFatal(t)
		mgr, decay, _ = testDecayTracker(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100))
	if err != nil {
		t.Fatal(err)
	}

	tag2, err := decay.RegisterDecayingTag("bop", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100))
	if err != nil {
		t.Fatal(err)
	}

	tag3, err := decay.RegisterDecayingTag("foo", 250*time.Millisecond, connmgr.DecayNone(), connmgr.BumpSumBounded(0, 100))
	if err != nil {
		t.Fatal(err)
	}

	_ = tag1.Bump(id, 100)
	_ = tag2.Bump(id, 100)
	_ = tag3.Bump(id, 100)
	_ = tag1.Bump(id, 100)
	_ = tag2.Bump(id, 100)
	_ = tag3.Bump(id, 100)

	<-time.After(500 * time.Millisecond)

	// all tags are upper-bounded, so the score must be 300
	ti := mgr.GetTagInfo(id)
	if v := ti.Value; v != 300 {
		t.Fatalf("wrong value; expected = %d; got = %d", 300, v)
	}

	for _, s := range []string{"beep", "bop", "foo"} {
		if v, ok := ti.Tags[s]; !ok || v != 100 {
			t.Fatalf("expected tag %s to be 100; was = %d", s, v)
		}
	}
}

func TestCustomFunctions(t *testing.T) {
	var (
		id                    = tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayFixed(10), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	tag2, err := decay.RegisterDecayingTag("bop", 100*time.Millisecond, connmgr.DecayFixed(5), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	tag3, err := decay.RegisterDecayingTag("foo", 50*time.Millisecond, connmgr.DecayFixed(1), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	_ = tag1.Bump(id, 1000)
	_ = tag2.Bump(id, 1000)
	_ = tag3.Bump(id, 1000)

	<-time.After(500 * time.Millisecond)

	// no decay has occurred yet, so score must be 3000.
	if v := mgr.GetTagInfo(id).Value; v != 3000 {
		t.Fatalf("wrong value; expected = %d; got = %d", 3000, v)
	}

	// only tag3 should tick.
	mockClock.Add(50 * time.Millisecond)
	if v := mgr.GetTagInfo(id).Value; v != 2999 {
		t.Fatalf("wrong value; expected = %d; got = %d", 2999, v)
	}

	// tag3 will tick thrice, tag2 will tick twice.
	mockClock.Add(150 * time.Millisecond)
	if v := mgr.GetTagInfo(id).Value; v != 2986 {
		t.Fatalf("wrong value; expected = %d; got = %d", 2986, v)
	}

	// tag3 will tick once, tag1 will tick once.
	mockClock.Add(50 * time.Millisecond)
	if v := mgr.GetTagInfo(id).Value; v != 2975 {
		t.Fatalf("wrong value; expected = %d; got = %d", 2975, v)
	}
}
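
// Worked arithmetic for the assertions above: the three tags start at 1000 each
// (3000 total). tag3 loses 1 every 50ms, tag2 loses 5 every 100ms, tag1 loses 10
// every 250ms.
//   +50ms:  only tag3 ticks         -> 3000 - 1       = 2999
//   +150ms: tag3 ticks 3x, tag2 2x  -> 2999 - 3 - 10  = 2986
//   +50ms:  tag3 ticks 1x, tag1 1x  -> 2986 - 1 - 10  = 2975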

func TestMultiplePeers(t *testing.T) {
	var (
		ids                   = []peer.ID{tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t)}
		mgr, decay, mockClock = testDecayTracker(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayFixed(10), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	tag2, err := decay.RegisterDecayingTag("bop", 100*time.Millisecond, connmgr.DecayFixed(5), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	tag3, err := decay.RegisterDecayingTag("foo", 50*time.Millisecond, connmgr.DecayFixed(1), connmgr.BumpSumUnbounded())
	if err != nil {
		t.Fatal(err)
	}

	_ = tag1.Bump(ids[0], 1000)
	_ = tag2.Bump(ids[0], 1000)
	_ = tag3.Bump(ids[0], 1000)

	_ = tag1.Bump(ids[1], 500)
	_ = tag2.Bump(ids[1], 500)
	_ = tag3.Bump(ids[1], 500)

	_ = tag1.Bump(ids[2], 100)
	_ = tag2.Bump(ids[2], 100)
	_ = tag3.Bump(ids[2], 100)

	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	mockClock.Add(3 * time.Second)

	// allow the background goroutine to process ticks.
	<-time.After(500 * time.Millisecond)

	if v := mgr.GetTagInfo(ids[0]).Value; v != 2670 {
		t.Fatalf("wrong value; expected = %d; got = %d", 2670, v)
	}

	if v := mgr.GetTagInfo(ids[1]).Value; v != 1170 {
		t.Fatalf("wrong value; expected = %d; got = %d", 1170, v)
	}

	if v := mgr.GetTagInfo(ids[2]).Value; v != 40 {
		t.Fatalf("wrong value; expected = %d; got = %d", 40, v)
	}
}

func TestLinearDecayOverwrite(t *testing.T) {
	var (
		id                    = tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", 250*time.Millisecond, connmgr.DecayLinear(0.5), connmgr.BumpOverwrite())
	if err != nil {
		t.Fatal(err)
	}

	_ = tag1.Bump(id, 1000)
	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	mockClock.Add(250 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 500 {
		t.Fatalf("value should be half; got = %d", v)
	}

	mockClock.Add(250 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 250 {
		t.Fatalf("value should be half; got = %d", v)
	}

	_ = tag1.Bump(id, 1000)
	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	if v := mgr.GetTagInfo(id).Value; v != 1000 {
		t.Fatalf("value should be 1000; got = %d", v)
	}
}
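
// For reference (assumed semantics, consistent with the assertions above):
// DecayLinear(0.5) halves the value on every tick, and BumpOverwrite() replaces
// the stored value with the new delta instead of adding to it. A pair with the
// same shape, under those assumptions:
//
//	halve := connmgr.DecayFn(func(v connmgr.DecayingValue) (after int, rm bool) {
//		after = int(float64(v.Value) * 0.5)
//		return after, after <= 0
//	})
//	overwrite := connmgr.BumpFn(func(v connmgr.DecayingValue, delta int) int {
//		return delta
//	})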

func TestResolutionMisaligned(t *testing.T) {
	var (
		id                    = tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
		require               = require.New(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", time.Duration(float64(TestResolution)*1.4), connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	tag2, err := decay.RegisterDecayingTag("bop", time.Duration(float64(TestResolution)*2.4), connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	_ = tag1.Bump(id, 1000)
	_ = tag2.Bump(id, 1000)
	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	// first tick.
	mockClock.Add(TestResolution)
	require.Equal(1000, mgr.GetTagInfo(id).Tags["beep"])
	require.Equal(1000, mgr.GetTagInfo(id).Tags["bop"])

	// next tick; tag1 would've ticked.
	mockClock.Add(TestResolution)
	require.Equal(999, mgr.GetTagInfo(id).Tags["beep"])
	require.Equal(1000, mgr.GetTagInfo(id).Tags["bop"])

	// next tick; tag1 would've ticked twice, tag2 once.
	mockClock.Add(TestResolution)
	require.Equal(998, mgr.GetTagInfo(id).Tags["beep"])
	require.Equal(999, mgr.GetTagInfo(id).Tags["bop"])

	require.Equal(1997, mgr.GetTagInfo(id).Value)
}
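
// Worked timeline for the assertions above (resolution = 50ms): tag1's interval is
// 70ms (1.4x) and tag2's is 120ms (2.4x), but decay only runs on the 50ms ticker.
//   t=50ms:  tag1's next tick is at 70ms, tag2's at 120ms -> neither decays (1000 / 1000)
//   t=100ms: tag1 decays once (999), next tick moves to 140ms; tag2 still waits (1000)
//   t=150ms: tag1 decays again (998); tag2 decays once (999); total = 1997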

func TestTagRemoval(t *testing.T) {
	var (
		id1, id2              = tu.RandPeerIDFatal(t), tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
		require               = require.New(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	tag2, err := decay.RegisterDecayingTag("bop", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	// id1 has both tags; id2 only has the first tag.
	_ = tag1.Bump(id1, 1000)
	_ = tag2.Bump(id1, 1000)
	_ = tag1.Bump(id2, 1000)

	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	// first tick.
	mockClock.Add(TestResolution)
	require.Equal(999, mgr.GetTagInfo(id1).Tags["beep"])
	require.Equal(999, mgr.GetTagInfo(id1).Tags["bop"])
	require.Equal(999, mgr.GetTagInfo(id2).Tags["beep"])

	require.Equal(999*2, mgr.GetTagInfo(id1).Value)
	require.Equal(999, mgr.GetTagInfo(id2).Value)

	// remove tag1 from id1.
	err = tag1.Remove(id1)

	// allow the background goroutine to process the removal.
	<-time.After(500 * time.Millisecond)
	require.NoError(err)

	// next tick. both peers only have 1 tag, both at 998 value.
	mockClock.Add(TestResolution)
	require.Zero(mgr.GetTagInfo(id1).Tags["beep"])
	require.Equal(998, mgr.GetTagInfo(id1).Tags["bop"])
	require.Equal(998, mgr.GetTagInfo(id2).Tags["beep"])

	require.Equal(998, mgr.GetTagInfo(id1).Value)
	require.Equal(998, mgr.GetTagInfo(id2).Value)

	// remove tag1 from id1 again; no error.
	err = tag1.Remove(id1)
	require.NoError(err)
}

func TestTagClosure(t *testing.T) {
	var (
		id                    = tu.RandPeerIDFatal(t)
		mgr, decay, mockClock = testDecayTracker(t)
		require               = require.New(t)
	)

	tag1, err := decay.RegisterDecayingTag("beep", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	tag2, err := decay.RegisterDecayingTag("bop", TestResolution, connmgr.DecayFixed(1), connmgr.BumpOverwrite())
	require.NoError(err)

	_ = tag1.Bump(id, 1000)
	_ = tag2.Bump(id, 1000)
	// allow the background goroutine to process bumps.
	<-time.After(500 * time.Millisecond)

	// first tick; both tags decay by 1.
	mockClock.Add(TestResolution)
	require.Equal(999, mgr.GetTagInfo(id).Tags["beep"])
	require.Equal(999, mgr.GetTagInfo(id).Tags["bop"])
	require.Equal(999*2, mgr.GetTagInfo(id).Value)

	// next tick; both tags decay again.
	mockClock.Add(TestResolution)
	require.Equal(998, mgr.GetTagInfo(id).Tags["beep"])
	require.Equal(998, mgr.GetTagInfo(id).Tags["bop"])
	require.Equal(998*2, mgr.GetTagInfo(id).Value)

	// close the tag.
	err = tag1.Close()
	require.NoError(err)

	// allow the background goroutine to process the closure.
	<-time.After(500 * time.Millisecond)
	require.Equal(998, mgr.GetTagInfo(id).Value)

	// a second closure should not error.
	err = tag1.Close()
	require.NoError(err)

	// bumping a tag after it's been closed should error.
	err = tag1.Bump(id, 5)
	require.Error(err)
}

func testDecayTracker(tb testing.TB) (*BasicConnMgr, connmgr.Decayer, *clock.Mock) {
	mockClock := clock.NewMock()
	cfg := &DecayerCfg{
		Resolution: TestResolution,
		Clock:      mockClock,
	}

	mgr, err := NewConnManager(10, 10, WithGracePeriod(time.Second), DecayerConfig(cfg))
	require.NoError(tb, err)
	decay, ok := connmgr.SupportsDecay(mgr)
	if !ok {
		tb.Fatalf("connmgr does not support decay")
	}
	tb.Cleanup(func() {
		mgr.Close()
		decay.Close()
	})

	return mgr, decay, mockClock
}
@ -0,0 +1,53 @@
package connmgr

import (
	"errors"
	"time"
)

// config is the configuration struct for the basic connection manager.
type config struct {
	highWater     int
	lowWater      int
	gracePeriod   time.Duration
	silencePeriod time.Duration
	decayer       *DecayerCfg
	emergencyTrim bool
}

// Option represents an option for the basic connection manager.
type Option func(*config) error

// DecayerConfig applies a configuration for the decayer.
func DecayerConfig(opts *DecayerCfg) Option {
	return func(cfg *config) error {
		cfg.decayer = opts
		return nil
	}
}

// WithGracePeriod sets the grace period.
// The grace period is the time a newly opened connection is given before it becomes
// subject to pruning.
func WithGracePeriod(p time.Duration) Option {
	return func(cfg *config) error {
		if p < 0 {
			return errors.New("grace period must be non-negative")
		}
		cfg.gracePeriod = p
		return nil
	}
}

// WithSilencePeriod sets the silence period.
// The connection manager will perform a cleanup once per silence period
// if the number of connections surpasses the high watermark.
func WithSilencePeriod(p time.Duration) Option {
	return func(cfg *config) error {
		if p <= 0 {
			return errors.New("silence period must be non-zero")
		}
		cfg.silencePeriod = p
		return nil
	}
}
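
// Usage sketch (illustrative, not part of this diff): options are applied in
// order by NewConnManager; the watermarks and durations below are arbitrary:
//
//	mgr, err := NewConnManager(100, 400,
//		WithGracePeriod(30*time.Second),
//		WithSilencePeriod(10*time.Second),
//		DecayerConfig(&DecayerCfg{Resolution: time.Minute}),
//	)
//	if err != nil {
//		return err
//	}
//	defer mgr.Close()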
@ -0,0 +1,18 @@
//go:build cgo
// +build cgo

package connmgr

import "github.com/raulk/go-watchdog"

func registerWatchdog(cb func()) (unregister func()) {
	return watchdog.RegisterPostGCNotifee(cb)
}

// WithEmergencyTrim is an option to enable trimming connections on memory emergency.
func WithEmergencyTrim(enable bool) Option {
	return func(cfg *config) error {
		cfg.emergencyTrim = enable
		return nil
	}
}
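
// Usage sketch (illustrative, not part of this diff): with cgo available,
// enabling emergency trims lets a go-watchdog post-GC notification trigger an
// immediate connection trim:
//
//	mgr, err := NewConnManager(100, 400, WithEmergencyTrim(true))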
@ -0,0 +1,16 @@
//go:build !cgo
// +build !cgo

package connmgr

func registerWatchdog(func()) (unregister func()) {
	return nil
}

// WithEmergencyTrim is an option to enable trimming connections on memory emergency.
func WithEmergencyTrim(enable bool) Option {
	return func(cfg *config) error {
		log.Warn("platform doesn't support go-watchdog")
		return nil
	}
}