Browse Source

don't use a context to shut down the circuitv2

pull/1185/head
Marten Seemann 3 years ago
parent
commit
729386c495
  1. 14
      config/config.go
  2. 28
      p2p/protocol/circuitv2/client/client.go
  3. 6
      p2p/protocol/circuitv2/client/listen.go
  4. 6
      p2p/protocol/circuitv2/client/transport.go
  5. 4
      p2p/protocol/circuitv2/test/compat_test.go
  6. 23
      p2p/protocol/circuitv2/test/e2e_test.go

14
config/config.go

@ -137,7 +137,7 @@ func (cfg *Config) makeSwarm(ctx context.Context) (*swarm.Swarm, error) {
return swrm, nil
}
func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
func (cfg *Config) addTransports(h host.Host) (err error) {
swrm, ok := h.Network().(transport.TransportNetwork)
if !ok {
// Should probably skip this if no transports.
@ -165,15 +165,13 @@ func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
return err
}
for _, t := range tpts {
err = swrm.AddTransport(t)
if err != nil {
if err := swrm.AddTransport(t); err != nil {
return err
}
}
if cfg.Relay {
err := circuitv2.AddTransport(ctx, h, upgrader)
if err != nil {
if err := circuitv2.AddTransport(h, upgrader); err != nil {
h.Close()
return err
}
@ -225,8 +223,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
}
}
err = cfg.addTransports(ctx, h)
if err != nil {
if err := cfg.addTransports(h); err != nil {
h.Close()
return nil, err
}
@ -314,8 +311,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
return nil, err
}
dialerHost := blankhost.NewBlankHost(dialer)
err = autoNatCfg.addTransports(ctx, dialerHost)
if err != nil {
if err := autoNatCfg.addTransports(dialerHost); err != nil {
dialerHost.Close()
h.Close()
return nil, err

28
p2p/protocol/circuitv2/client/client.go

@ -2,12 +2,14 @@ package client
import (
"context"
"io"
"sync"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport"
logging "github.com/ipfs/go-log"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
@ -24,9 +26,10 @@ var log = logging.Logger("p2p-circuit")
// This allows us to use the v2 code as drop in replacement for v1 in a host without breaking
// existing code and interoperability with older nodes.
type Client struct {
ctx context.Context
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
host host.Host
upgrader *tptu.Upgrader
incoming chan accept
@ -35,6 +38,9 @@ type Client struct {
hopCount map[peer.ID]int
}
var _ io.Closer = &Client{}
var _ transport.Transport = &Client{}
type accept struct {
conn *Conn
writeResponse func() error
@ -48,15 +54,16 @@ type completion struct {
// New constructs a new p2p-circuit/v2 client, attached to the given host and using the given
// upgrader to perform connection upgrades.
func New(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
return &Client{
ctx: ctx,
func New(h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
cl := &Client{
host: h,
upgrader: upgrader,
incoming: make(chan accept),
activeDials: make(map[peer.ID]*completion),
hopCount: make(map[peer.ID]int),
}, nil
}
cl.ctx, cl.ctxCancel = context.WithCancel(context.Background())
return cl, nil
}
// Start registers the circuit (client) protocol stream handlers
@ -64,3 +71,10 @@ func (c *Client) Start() {
c.host.SetStreamHandler(proto.ProtoIDv1, c.handleStreamV1)
c.host.SetStreamHandler(proto.ProtoIDv2Stop, c.handleStreamV2)
}
func (c *Client) Close() error {
c.ctxCancel()
c.host.RemoveStreamHandler(proto.ProtoIDv1)
c.host.RemoveStreamHandler(proto.ProtoIDv2Stop)
return nil
}

6
p2p/protocol/circuitv2/client/listen.go

@ -1,6 +1,7 @@
package client
import (
"errors"
"net"
ma "github.com/multiformats/go-multiaddr"
@ -32,7 +33,7 @@ func (l *Listener) Accept() (manet.Conn, error) {
return evt.conn, nil
case <-l.ctx.Done():
return nil, l.ctx.Err()
return nil, errors.New("circuit v2 client closed")
}
}
}
@ -49,6 +50,5 @@ func (l *Listener) Multiaddr() ma.Multiaddr {
}
func (l *Listener) Close() error {
// noop for now
return nil
return (*Client)(l).Close()
}

6
p2p/protocol/circuitv2/client/transport.go

@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"io"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
@ -17,13 +18,13 @@ var circuitAddr = ma.Cast(circuitProtocol.VCode)
// AddTransport constructs a new p2p-circuit/v2 client and adds it as a transport to the
// host network
func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) error {
func AddTransport(h host.Host, upgrader *tptu.Upgrader) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}
c, err := New(ctx, h, upgrader)
c, err := New(h, upgrader)
if err != nil {
return fmt.Errorf("error constructing circuit client: %w", err)
}
@ -45,6 +46,7 @@ func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) err
// Transport interface
var _ transport.Transport = (*Client)(nil)
var _ io.Closer = (*Client)(nil)
func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
conn, err := c.dial(ctx, a, p)

4
p2p/protocol/circuitv2/test/compat_test.go

@ -30,7 +30,7 @@ func TestRelayCompatV2DialV1(t *testing.T) {
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransportV1(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@ -105,7 +105,7 @@ func TestRelayCompatV1DialV2(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, hosts[0], upgraders[0])
addTransportV1(t, ctx, hosts[2], upgraders[2])
rch := make(chan []byte, 1)

23
p2p/protocol/circuitv2/test/e2e_test.go

@ -20,12 +20,12 @@ import (
logging "github.com/ipfs/go-log"
bhost "github.com/libp2p/go-libp2p-blankhost"
metrics "github.com/libp2p/go-libp2p-core/metrics"
pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem"
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
tcp "github.com/libp2p/go-tcp-transport"
"github.com/libp2p/go-tcp-transport"
ma "github.com/multiformats/go-multiaddr"
)
@ -85,9 +85,8 @@ func connect(t *testing.T, a, b host.Host) {
}
}
func addTransport(t *testing.T, ctx context.Context, h host.Host, upgrader *tptu.Upgrader) {
err := client.AddTransport(ctx, h, upgrader)
if err != nil {
func addTransport(t *testing.T, h host.Host, upgrader *tptu.Upgrader) {
if err := client.AddTransport(h, upgrader); err != nil {
t.Fatal(err)
}
}
@ -97,8 +96,8 @@ func TestBasicRelay(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@ -184,8 +183,8 @@ func TestRelayLimitTime(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan error, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@ -258,8 +257,8 @@ func TestRelayLimitData(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan int, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {

Loading…
Cancel
Save