diff --git a/p2p/protocol/internal/circuitv1-deprecated/conn.go b/p2p/protocol/internal/circuitv1-deprecated/conn.go index 11f947633..1e46c4341 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/conn.go +++ b/p2p/protocol/internal/circuitv1-deprecated/conn.go @@ -13,10 +13,14 @@ import ( manet "github.com/multiformats/go-multiaddr-net" ) +// HopTagWeight is the connection manager weight for connections carrying relay hop streams +var HopTagWeight = 5 + type Conn struct { stream network.Stream remote peer.AddrInfo host host.Host + relay *Relay } type NetAddr struct { @@ -69,13 +73,28 @@ func (c *Conn) RemoteAddr() net.Addr { // by the connection manager, taking with them all the relayed connections (that may themselves // be protected). func (c *Conn) tagHop() { - c.host.ConnManager().UpsertTag(c.stream.Conn().RemotePeer(), "relay-hop-stream", incrementTag) + c.relay.mx.Lock() + defer c.relay.mx.Unlock() + + p := c.stream.Conn().RemotePeer() + c.relay.hopCount[p]++ + if c.relay.hopCount[p] == 1 { + c.host.ConnManager().TagPeer(p, "relay-hop-stream", HopTagWeight) + } } // Decrement the underlying relay connection tag by 1; this is performed when we close the // relayed connection. func (c *Conn) untagHop() { - c.host.ConnManager().UpsertTag(c.stream.Conn().RemotePeer(), "relay-hop-stream", decrementTag) + c.relay.mx.Lock() + defer c.relay.mx.Unlock() + + p := c.stream.Conn().RemotePeer() + c.relay.hopCount[p]-- + if c.relay.hopCount[p] == 0 { + c.host.ConnManager().UntagPeer(p, "relay-hop-stream") + delete(c.relay.hopCount, p) + } } // TODO: is it okay to cast c.Conn().RemotePeer() into a multiaddr? might be "user input" diff --git a/p2p/protocol/internal/circuitv1-deprecated/relay.go b/p2p/protocol/internal/circuitv1-deprecated/relay.go index 10816233b..ee7464106 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/relay.go +++ b/p2p/protocol/internal/circuitv1-deprecated/relay.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "sync" "sync/atomic" "time" @@ -52,6 +53,10 @@ type Relay struct { // atomic counters streamCount int32 liveHopCount int32 + + // per peer hop counters + mx sync.Mutex + hopCount map[peer.ID]int } // RelayOpts are options for configuring the relay transport. @@ -94,6 +99,7 @@ func NewRelay(ctx context.Context, h host.Host, upgrader *tptu.Upgrader, opts .. ctx: ctx, self: h.ID(), incoming: make(chan *Conn), + hopCount: make(map[peer.ID]int), } for _, opt := range opts { @@ -186,7 +192,7 @@ func (r *Relay) DialPeer(ctx context.Context, relay peer.AddrInfo, dest peer.Add return nil, RelayError{msg.GetCode()} } - return &Conn{stream: s, remote: dest, host: r.host}, nil + return &Conn{stream: s, remote: dest, host: r.host, relay: r}, nil } func (r *Relay) Matches(addr ma.Multiaddr) bool { @@ -449,7 +455,7 @@ func (r *Relay) handleStopStream(s network.Stream, msg *pb.CircuitRelay) { } select { - case r.incoming <- &Conn{stream: s, remote: src, host: r.host}: + case r.incoming <- &Conn{stream: s, remote: src, host: r.host, relay: r}: case <-time.After(RelayAcceptTimeout): r.handleError(s, pb.CircuitRelay_STOP_RELAY_REFUSED) }