Browse Source

noise: implement an API to send and receive early data

pull/1728/head
Marten Seemann 2 years ago
parent
commit
8be14f480d
  1. 2
      p2p/security/noise/crypto_test.go
  2. 41
      p2p/security/noise/handshake.go
  3. 20
      p2p/security/noise/session.go
  4. 23
      p2p/security/noise/session_transport.go
  5. 4
      p2p/security/noise/transport.go
  6. 135
      p2p/security/noise/transport_test.go

2
p2p/security/noise/crypto_test.go

@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")

41
p2p/security/noise/handshake.go

@ -10,17 +10,16 @@ import (
"runtime/debug"
"time"
"github.com/minio/sha256-simd"
"golang.org/x/crypto/chacha20poly1305"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/flynn/noise"
"github.com/gogo/protobuf/proto"
pool "github.com/libp2p/go-buffer-pool"
"github.com/minio/sha256-simd"
)
// payloadSigPrefix is prepended to our Noise static key before signing with
@ -89,10 +88,12 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if s.initiator {
// stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret.
// Handshake Msg Len = len(DH ephemeral key)
err = s.sendHandshakeMessage(hs, nil, hbuf)
if err != nil {
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
@ -101,29 +102,34 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
return err
}
// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
return nil
} else {
// stage 0 //
// We don't expect any payload on the first message.
if _, err := s.readHandshakeMessage(hs); err != nil {
initialPayload, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil {
return err
}
} else if len(initialPayload) > 0 {
return fmt.Errorf("received unexpected early data (%d bytes)", len(initialPayload))
}
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
@ -132,13 +138,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
}
return nil
}
// setCipherStates sets the initial cipher states that will be used to protect

20
p2p/security/noise/session.go

@ -36,20 +36,22 @@ type secureSession struct {
dec *noise.CipherState
// noise prologue
prologue []byte
prologue []byte
earlyDataHandler EarlyDataHandler
}
// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
prologue: prologue,
earlyDataHandler: edh,
}
// the go-routine we create to run the handshake will

23
p2p/security/noise/session_transport.go

@ -22,6 +22,22 @@ func Prologue(prologue []byte) SessionOption {
}
}
// EarlyDataHandler allows attaching an (unencrypted) application payload to the first handshake message.
// While unencrypted, the integrity of this early data is retroactively authenticated on completion of the handshake.
type EarlyDataHandler interface {
// Send is called for the client before sending the first handshake message.
Send(context.Context, net.Conn, peer.ID) []byte
// Received is called for the server when the first handshake message from the client is received.
Received(context.Context, net.Conn, []byte) error
}
func EarlyData(h EarlyDataHandler) SessionOption {
return func(s *SessionTransport) error {
s.earlyDataHandler = h
return nil
}
}
var _ sec.SecureTransport = &SessionTransport{}
// SessionTransport can be used
@ -29,13 +45,14 @@ var _ sec.SecureTransport = &SessionTransport{}
type SessionTransport struct {
t *Transport
// options
prologue []byte
prologue []byte
earlyDataHandler EarlyDataHandler
}
// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, false)
c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
@ -47,5 +64,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn,
// SecureOutbound runs the Noise handshake as the initiator.
func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(i.t, ctx, insecure, p, i.prologue, true)
return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true)
}

4
p2p/security/noise/transport.go

@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(t, ctx, insecure, p, nil, false)
c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer
// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, nil, true)
return newSecureSession(t, ctx, insecure, p, nil, nil, true)
}
func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) {

135
p2p/security/noise/transport_test.go

@ -78,10 +78,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
done := make(chan struct{})
go func() {
defer close(done)
initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID)
}()
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "")
<-done
if initErr != nil {
@ -171,7 +171,7 @@ func TestPeerIDMatch(t *testing.T) {
done := make(chan struct{})
go func() {
defer close(done)
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID)
assert.NoError(t, err)
assert.Equal(t, conn.RemotePeer(), respTransport.localID)
b := make([]byte, 6)
@ -180,7 +180,7 @@ func TestPeerIDMatch(t *testing.T) {
assert.Equal(t, b, []byte("foobar"))
}()
conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID)
conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID)
require.NoError(t, err)
require.Equal(t, conn.RemotePeer(), initTransport.localID)
_, err = conn.Write([]byte("foobar"))
@ -194,11 +194,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
errChan := make(chan error)
go func() {
_, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id")
_, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id")
errChan <- err
}()
_, err := respTransport.SecureInbound(context.TODO(), resp, "")
_, err := respTransport.SecureInbound(context.Background(), resp, "")
require.Error(t, err)
initErr := <-errChan
@ -214,13 +214,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
done := make(chan struct{})
go func() {
defer close(done)
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID)
assert.NoError(t, err)
_, err = conn.Read([]byte{0})
assert.Error(t, err)
}()
_, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id")
_, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id")
require.Error(t, err, "expected responder to fail with peer ID mismatch error")
<-done
}
@ -387,7 +387,7 @@ func TestPrologueMatches(t *testing.T) {
tpt, err := initTransport.
WithSessionOptions(Prologue(commonPrologue))
require.NoError(t, err)
conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID)
conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID)
require.NoError(t, err)
defer conn.Close()
}()
@ -395,7 +395,7 @@ func TestPrologueMatches(t *testing.T) {
tpt, err := respTransport.
WithSessionOptions(Prologue(commonPrologue))
require.NoError(t, err)
conn, err := tpt.SecureInbound(context.TODO(), respConn, "")
conn, err := tpt.SecureInbound(context.Background(), respConn, "")
require.NoError(t, err)
defer conn.Close()
<-done
@ -415,14 +415,125 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) {
tpt, err := initTransport.
WithSessionOptions(Prologue(initPrologue))
require.NoError(t, err)
_, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID)
_, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID)
require.Error(t, err)
}()
tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue))
require.NoError(t, err)
_, err = tpt.SecureInbound(context.TODO(), respConn, "")
_, err = tpt.SecureInbound(context.Background(), respConn, "")
require.Error(t, err)
<-done
}
type earlyDataHandler struct {
send func(context.Context, net.Conn, peer.ID) []byte
received func(context.Context, net.Conn, []byte) error
}
func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) []byte {
if e.send == nil {
return nil
}
return e.send(ctx, conn, id)
}
func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []byte) error {
if e.received == nil {
return nil
}
return e.received(ctx, conn, data)
}
func TestEarlyDataAccepted(t *testing.T) {
var receivedEarlyData []byte
serverEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error {
receivedEarlyData = data
return nil
},
}
clientEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)
initConn, respConn := newConnPair(t)
errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()
conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.NoError(t, err)
defer conn.Close()
require.Equal(t, []byte("foobar"), receivedEarlyData)
}
func TestEarlyDataRejected(t *testing.T) {
serverEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") },
}
clientEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)
initConn, respConn := newConnPair(t)
errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()
_, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.Error(t, err)
select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
require.EqualError(t, err, "nope")
}
}
func TestEarlyDataRejectedWithNoHandler(t *testing.T) {
clientEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
initConn, respConn := newConnPair(t)
errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()
_, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID)
require.Error(t, err)
select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
require.Error(t, err)
require.Contains(t, err.Error(), "received unexpected early data")
}
}

Loading…
Cancel
Save