From cbe50095f327802f551408566c2442b746693177 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 8 Nov 2022 10:46:39 +0000 Subject: [PATCH] config: use fx to construct security transports --- config/config.go | 30 +++++--- config/constructor_types.go | 2 - config/security.go | 71 ++----------------- core/sec/insecure/insecure.go | 19 +++-- core/sec/insecure/insecure_test.go | 2 +- core/sec/security.go | 4 ++ libp2p_test.go | 41 +++++++++++ options.go | 26 +++++-- .../conn-security-multistream/ssms_test.go | 8 +-- p2p/net/swarm/dial_worker_test.go | 2 +- p2p/net/swarm/testing/testing.go | 2 +- p2p/net/upgrader/upgrader_test.go | 2 +- p2p/security/noise/benchmark_test.go | 2 +- p2p/security/noise/session_transport.go | 7 ++ p2p/security/noise/transport.go | 22 +++--- p2p/security/tls/cmd/tlsdiag/client.go | 2 +- p2p/security/tls/cmd/tlsdiag/server.go | 2 +- p2p/security/tls/transport.go | 26 ++++--- p2p/security/tls/transport_test.go | 24 +++---- p2p/transport/tcp/tcp_test.go | 2 +- p2p/transport/websocket/websocket_test.go | 4 +- p2p/transport/webtransport/transport.go | 2 +- 22 files changed, 165 insertions(+), 137 deletions(-) diff --git a/config/config.go b/config/config.go index 97578093f..d334c294e 100644 --- a/config/config.go +++ b/config/config.go @@ -13,8 +13,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/routing" - "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/autorelay" @@ -74,7 +74,7 @@ type Config struct { Transports []fx.Option Muxers []MsMuxC - SecurityTransports []MsSecC + SecurityTransports []fx.Option Insecure bool PSK pnet.PSK @@ -169,15 +169,17 @@ func (cfg *Config) addTransports(h host.Host) error { // Should probably skip this if no transports. return fmt.Errorf("swarm does not support transports") } - var secure sec.SecureMuxer + + muxers := make([]protocol.ID, 0, len(cfg.Muxers)) + for _, muxc := range cfg.Muxers { + muxers = append(muxers, protocol.ID(muxc.ID)) + } + + var security []fx.Option if cfg.Insecure { - secure = makeInsecureTransport(h.ID(), cfg.PeerKey) + security = append(security, fx.Provide(makeInsecureTransport)) } else { - var err error - secure, err = makeSecurityMuxer(h, cfg.SecurityTransports, cfg.Muxers) - if err != nil { - return err - } + security = cfg.SecurityTransports } muxer, err := makeMuxer(h, cfg.Muxers) if err != nil { @@ -185,9 +187,14 @@ func (cfg *Config) addTransports(h host.Host) error { } fxopts := []fx.Option{ + fx.NopLogger, fx.Provide(tptu.New), - fx.Provide(func() sec.SecureMuxer { return secure }), fx.Provide(func() network.Multiplexer { return muxer }), + fx.Provide(fx.Annotate( + makeSecurityMuxer, + fx.ParamTags(`group:"security"`), + )), + fx.Supply(muxers), fx.Provide(func() host.Host { return h }), fx.Provide(func() crypto.PrivKey { return h.Peerstore().PrivKey(h.ID()) }), fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), @@ -196,6 +203,9 @@ func (cfg *Config) addTransports(h host.Host) error { fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), } fxopts = append(fxopts, cfg.Transports...) + if !cfg.Insecure { + fxopts = append(fxopts, security...) + } fxopts = append(fxopts, fx.Invoke( fx.Annotate( diff --git a/config/constructor_types.go b/config/constructor_types.go index 362f51ddf..a796914a7 100644 --- a/config/constructor_types.go +++ b/config/constructor_types.go @@ -12,7 +12,6 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" madns "github.com/multiformats/go-multiaddr-dns" @@ -23,7 +22,6 @@ var ( hostType = reflect.TypeOf((*host.Host)(nil)).Elem() networkType = reflect.TypeOf((*network.Network)(nil)).Elem() muxType = reflect.TypeOf((*network.Multiplexer)(nil)).Elem() - securityType = reflect.TypeOf((*sec.SecureTransport)(nil)).Elem() privKeyType = reflect.TypeOf((*crypto.PrivKey)(nil)).Elem() pubKeyType = reflect.TypeOf((*crypto.PubKey)(nil)).Elem() pstoreType = reflect.TypeOf((*peerstore.Peerstore)(nil)).Elem() diff --git a/config/security.go b/config/security.go index 6f15b6346..1a88575f5 100644 --- a/config/security.go +++ b/config/security.go @@ -1,84 +1,23 @@ package config import ( - "fmt" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" ) -// SecC is a security transport constructor. -type SecC func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) - -// MsSecC is a tuple containing a security transport constructor and a protocol -// ID. -type MsSecC struct { - SecC - ID string -} - -var securityArgTypes = newArgTypeSet( - hostType, networkType, peerIDType, - privKeyType, pubKeyType, pstoreType, - muxersType, -) - -// SecurityConstructor creates a security constructor from the passed parameter -// using reflection. -func SecurityConstructor(security interface{}) (SecC, error) { - // Already constructed? - if t, ok := security.(sec.SecureTransport); ok { - return func(_ host.Host, _ []protocol.ID) (sec.SecureTransport, error) { - return t, nil - }, nil - } - - ctor, err := makeConstructor(security, securityType, securityArgTypes) - if err != nil { - return nil, err - } - return func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) { - t, err := ctor(h, nil, nil, nil, nil, nil, muxers) - if err != nil { - return nil, err - } - return t.(sec.SecureTransport), nil - }, nil -} - func makeInsecureTransport(id peer.ID, privKey crypto.PrivKey) sec.SecureMuxer { secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, privKey)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, privKey)) return secMuxer } -func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureMuxer, error) { +func makeSecurityMuxer(tpts []sec.SecureTransport) sec.SecureMuxer { secMuxer := new(csms.SSMuxer) - transportSet := make(map[string]struct{}, len(tpts)) - for _, tptC := range tpts { - if _, ok := transportSet[tptC.ID]; ok { - return nil, fmt.Errorf("duplicate security transport: %s", tptC.ID) - } - transportSet[tptC.ID] = struct{}{} + for _, tpt := range tpts { + secMuxer.AddTransport(string(tpt.ID()), tpt) } - muxIds := make([]protocol.ID, 0, len(muxers)) - for _, muxc := range muxers { - muxIds = append(muxIds, protocol.ID(muxc.ID)) - } - for _, tptC := range tpts { - tpt, err := tptC.SecC(h, muxIds) - if err != nil { - return nil, err - } - if _, ok := tpt.(*insecure.Transport); ok { - return nil, fmt.Errorf("cannot construct libp2p with an insecure transport, set the Insecure config option instead") - } - secMuxer.AddTransport(tptC.ID, tpt) - } - return secMuxer, nil + return secMuxer } diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 2d94f4380..d2487a3b0 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -12,6 +12,7 @@ import ( ci "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" pb "github.com/libp2p/go-libp2p/core/sec/insecure/pb" @@ -28,18 +29,22 @@ const ID = "/plaintext/2.0.0" // peer presents as their ID and public key. // No authentication of the remote identity is performed. type Transport struct { - id peer.ID - key ci.PrivKey + id peer.ID + key ci.PrivKey + protocolID protocol.ID } +var _ sec.SecureTransport = &Transport{} + // NewWithIdentity constructs a new insecure transport. The provided private key // is stored and returned from LocalPrivateKey to satisfy the // SecureTransport interface, and the public key is sent to // remote peers. No security is provided. -func NewWithIdentity(id peer.ID, key ci.PrivKey) *Transport { +func NewWithIdentity(protocolID protocol.ID, id peer.ID, key ci.PrivKey) *Transport { return &Transport{ - id: id, - key: key, + protocolID: protocolID, + id: id, + key: key, } } @@ -108,6 +113,10 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee return conn, nil } +func (t *Transport) ID() protocol.ID { + return t.protocolID +} + // Conn is the connection type returned by the insecure transport. type Conn struct { net.Conn diff --git a/core/sec/insecure/insecure_test.go b/core/sec/insecure/insecure_test.go index a3ce8314f..8663f1d97 100644 --- a/core/sec/insecure/insecure_test.go +++ b/core/sec/insecure/insecure_test.go @@ -61,7 +61,7 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { require.NoError(t, err) id, err := peer.IDFromPublicKey(pub) require.NoError(t, err) - return NewWithIdentity(id, priv) + return NewWithIdentity("/test/1.0.0", id, priv) } // Create a new pair of connected TCP sockets. diff --git a/core/sec/security.go b/core/sec/security.go index c192a56a9..8b733b5d0 100644 --- a/core/sec/security.go +++ b/core/sec/security.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ) // SecureConn is an authenticated, encrypted connection. @@ -24,6 +25,9 @@ type SecureTransport interface { // SecureOutbound secures an outbound connection. SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) + + // ID is the protocol ID of the security protocol. + ID() protocol.ID } // A SecureMuxer is a wrapper around SecureTransport which can select security protocols diff --git a/libp2p_test.go b/libp2p_test.go index 094fdefbc..8e70eef3c 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -12,6 +12,8 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/security/noise" + tls "github.com/libp2p/go-libp2p/p2p/security/tls" "github.com/libp2p/go-libp2p/p2p/transport/tcp" "github.com/stretchr/testify/require" @@ -165,3 +167,42 @@ func TestChainOptions(t *testing.T) { } } } + +func TestSecurityConstructor(t *testing.T) { + h, err := New( + Transport(tcp.NewTCPTransport), + Security("/noisy", noise.New), + Security("/tls", tls.New), + DefaultListenAddrs, + DisableRelay(), + ) + require.NoError(t, err) + defer h.Close() + + h1, err := New( + NoListenAddrs, + Transport(tcp.NewTCPTransport), + Security("/noise", noise.New), // different name + DisableRelay(), + ) + require.NoError(t, err) + defer h1.Close() + + h2, err := New( + NoListenAddrs, + Transport(tcp.NewTCPTransport), + Security("/noisy", noise.New), + DisableRelay(), + ) + require.NoError(t, err) + defer h2.Close() + + ai := peer.AddrInfo{ + ID: h.ID(), + Addrs: h.Addrs(), + } + err = h1.Connect(context.Background(), ai) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to negotiate security protocol") + require.NoError(t, h2.Connect(context.Background(), ai)) +} diff --git a/options.go b/options.go index ac5ce84d7..3d1be2022 100644 --- a/options.go +++ b/options.go @@ -16,6 +16,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" @@ -63,17 +65,27 @@ func ListenAddrs(addrs ...ma.Multiaddr) Option { // * Host // * Network // * Peerstore -func Security(name string, tpt interface{}) Option { - stpt, err := config.SecurityConstructor(tpt) - err = traceError(err, 1) +func Security(name string, constructor interface{}) Option { return func(cfg *Config) error { - if err != nil { - return err - } if cfg.Insecure { return fmt.Errorf("cannot use security transports with an insecure libp2p configuration") } - cfg.SecurityTransports = append(cfg.SecurityTransports, config.MsSecC{SecC: stpt, ID: name}) + fxName := fmt.Sprintf(`name:"%s"`, name) + // provide the name of the security transport + cfg.SecurityTransports = append(cfg.SecurityTransports, + fx.Provide(fx.Annotate( + func() protocol.ID { return protocol.ID(name) }, + fx.ResultTags(fxName), + )), + ) + cfg.SecurityTransports = append(cfg.SecurityTransports, + fx.Provide(fx.Annotate( + constructor, + fx.ParamTags(fxName), + fx.As(new(sec.SecureTransport)), + fx.ResultTags(`group:"security"`), + )), + ) return nil } } diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index 5aa5db352..3ccf4a7f2 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -44,8 +44,8 @@ func TestCommonProto(t *testing.T) { var at, bt SSMuxer - atInsecure := insecure.NewWithIdentity(idA, privA) - btInsecure := insecure.NewWithIdentity(idB, privB) + atInsecure := insecure.NewWithIdentity(insecure.ID, idA, privA) + btInsecure := insecure.NewWithIdentity(insecure.ID, idB, privB) at.AddTransport("/plaintext/1.0.0", atInsecure) bt.AddTransport("/plaintext/1.1.0", btInsecure) bt.AddTransport("/plaintext/1.0.0", btInsecure) @@ -88,8 +88,8 @@ func TestNoCommonProto(t *testing.T) { privB, idB := newPeer(t) var at, bt SSMuxer - atInsecure := insecure.NewWithIdentity(idA, privA) - btInsecure := insecure.NewWithIdentity(idB, privB) + atInsecure := insecure.NewWithIdentity(insecure.ID, idA, privA) + btInsecure := insecure.NewWithIdentity(insecure.ID, idB, privB) at.AddTransport("/plaintext/1.0.0", atInsecure) bt.AddTransport("/plaintext/1.1.0", btInsecure) diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index b6c6be0ca..728c81da3 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -77,7 +77,7 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) stMuxer := msmux.NewBlankTransport() stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 2573c51ba..a28b488ef 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -103,7 +103,7 @@ func GenUpgrader(t *testing.T, n *swarm.Swarm, connGater connmgr.ConnectionGater id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) stMuxer := msmux.NewBlankTransport() stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index 51fa6f71b..d39d36022 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -44,7 +44,7 @@ func createUpgraderWithMuxer(t *testing.T, muxer network.Multiplexer, rcmgr netw require.NoError(t, err) id, err := peer.IDFromPrivateKey(priv) require.NoError(t, err) - u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(id, priv)}, muxer, nil, rcmgr, connGater, opts...) + u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(insecure.ID, id, priv)}, muxer, nil, rcmgr, connGater, opts...) require.NoError(t, err) return id, u } diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 836275b95..d59a1cb97 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -39,7 +39,7 @@ func makeTransport(b *testing.B) *Transport { if err != nil { b.Fatal(err) } - tpt, err := New(priv, nil) + tpt, err := New(ID, priv, nil) if err != nil { b.Fatalf("error constructing transport: %v", err) } diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index c42271bd9..0f26f3fa8 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -6,6 +6,7 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" @@ -71,6 +72,8 @@ type SessionTransport struct { prologue []byte disablePeerIDCheck bool + protocolID protocol.ID + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } @@ -92,3 +95,7 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true, !i.disablePeerIDCheck) } + +func (i *SessionTransport) ID() protocol.ID { + return i.protocolID +} diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index 687892950..e436c46aa 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -15,24 +15,21 @@ import ( ) // ID is the protocol ID for noise -const ( - ID = "/noise" - maxProtoNum = 100 -) - -var _ sec.SecureTransport = &Transport{} +const ID = "/noise" +const maxProtoNum = 100 -// Transport implements the interface sec.SecureTransport -// https://godoc.org/github.com/libp2p/go-libp2p/core/sec#SecureConn type Transport struct { + protocolID protocol.ID localID peer.ID privateKey crypto.PrivKey muxers []string } +var _ sec.SecureTransport = &Transport{} + // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { +func New(id protocol.ID, privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err @@ -44,6 +41,7 @@ func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { } return &Transport{ + protocolID: id, localID: localID, privateKey: privkey, muxers: smuxers, @@ -75,7 +73,7 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee } func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport, error) { - st := &SessionTransport{t: t} + st := &SessionTransport{t: t, protocolID: t.protocolID} for _, opt := range opts { if err := opt(st); err != nil { return nil, err @@ -84,6 +82,10 @@ func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport return st, nil } +func (t *Transport) ID() protocol.ID { + return t.protocolID +} + func matchMuxers(initiatorMuxers, responderMuxers []string) string { for _, muxer := range responderMuxers { for _, initMuxer := range initiatorMuxers { diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index 2292bfe0e..3868afebb 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -34,7 +34,7 @@ func StartClient() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv, nil) + tp, err := libp2ptls.New(libp2ptls.ID, priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 05e4be3f1..76c45a155 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -27,7 +27,7 @@ func StartServer() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv, nil) + tp, err := libp2ptls.New(libp2ptls.ID, priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 695f64846..754a33513 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -26,21 +26,25 @@ const ID = "/tls/1.0.0" type Transport struct { identity *Identity - localPeer peer.ID - privKey ci.PrivKey - muxers []protocol.ID + localPeer peer.ID + privKey ci.PrivKey + muxers []protocol.ID + protocolID protocol.ID } +var _ sec.SecureTransport = &Transport{} + // New creates a TLS encrypted transport -func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { - id, err := peer.IDFromPrivateKey(key) +func New(id protocol.ID, key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { + localPeer, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err } t := &Transport{ - localPeer: id, - privKey: key, - muxers: muxers, + protocolID: id, + localPeer: localPeer, + privKey: key, + muxers: muxers, } identity, err := NewIdentity(key) @@ -51,8 +55,6 @@ func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { return t, nil } -var _ sec.SecureTransport = &Transport{} - // SecureInbound runs the TLS handshake as a server. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { @@ -148,3 +150,7 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se connectionState: network.ConnectionState{NextProto: nextProto}, }, nil } + +func (t *Transport) ID() protocol.ID { + return t.protocolID +} diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 59fa8bdae..8c9fa7ced 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -126,9 +126,9 @@ func TestHandshakeSucceeds(t *testing.T) { } // Use standard transports with default TLS configuration - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) t.Run("standard TLS with extension not critical", func(t *testing.T) { @@ -240,9 +240,9 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { // Iterate through the NextProto combinations. for _, test := range tests { - clientTransport, err := New(clientKey, test.clientProtos) + clientTransport, err := New(ID, clientKey, test.clientProtos) require.NoError(t, err) - serverTransport, err := New(serverKey, test.serverProtos) + serverTransport, err := New(ID, serverKey, test.serverProtos) require.NoError(t, err) t.Run("TLS handshake with ALPN extension", func(t *testing.T) { @@ -268,9 +268,9 @@ func TestHandshakeConnectionCancellations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) t.Run("cancel outgoing connection", func(t *testing.T) { @@ -320,9 +320,9 @@ func TestPeerIDMismatch(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) t.Run("for outgoing connections", func(t *testing.T) { @@ -597,9 +597,9 @@ func TestInvalidCerts(t *testing.T) { tr := transforms[i] t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) tr.apply(clientTransport.identity) @@ -640,10 +640,10 @@ func TestInvalidCerts(t *testing.T) { }) t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) tr.apply(serverTransport.identity) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) clientInsecureConn, serverInsecureConn := connect(t) diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index b438f6eae..eec1657dd 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -153,6 +153,6 @@ func makeInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { id, err := peer.IDFromPrivateKey(priv) require.NoError(t, err) var secMuxer csms.SSMuxer - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, priv)) return id, &secMuxer } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index f4a457c3d..714fe89a6 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -67,7 +67,7 @@ func newInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, priv)) return id, &secMuxer } @@ -82,7 +82,7 @@ func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - noiseTpt, err := noise.New(priv, nil) + noiseTpt, err := noise.New(noise.ID, priv, nil) require.NoError(t, err) secMuxer.AddTransport(noise.ID, noiseTpt) return id, &secMuxer diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 3d902b4ed..bc1b5e143 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -110,7 +110,7 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } } - n, err := noise.New(key, nil) + n, err := noise.New(noise.ID, key, nil) if err != nil { return nil, err }