|
|
@ -41,62 +41,57 @@ func (s *secureSession) runHandshake(ctx context.Context) error { |
|
|
|
return fmt.Errorf("error initializing handshake state: %s", err) |
|
|
|
} |
|
|
|
|
|
|
|
s.ns.hs = hs |
|
|
|
s.ns.localStatic = kp |
|
|
|
|
|
|
|
payload, err := s.generateHandshakePayload() |
|
|
|
payload, err := s.generateHandshakePayload(kp) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
if s.initiator { |
|
|
|
// stage 0 //
|
|
|
|
err = s.sendHandshakeMessage(nil) |
|
|
|
err = s.sendHandshakeMessage(hs, nil) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error sending handshake message: %s", err) |
|
|
|
} |
|
|
|
|
|
|
|
// stage 1 //
|
|
|
|
plaintext, err := s.readHandshakeMessage() |
|
|
|
plaintext, err := s.readHandshakeMessage(hs) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error reading handshake message: %s", err) |
|
|
|
} |
|
|
|
err = s.handleRemoteHandshakePayload(plaintext) |
|
|
|
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
// stage 2 //
|
|
|
|
err = s.sendHandshakeMessage(payload) |
|
|
|
err = s.sendHandshakeMessage(hs, payload) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error sending handshake message: %s", err) |
|
|
|
} |
|
|
|
} else { |
|
|
|
// stage 0 //
|
|
|
|
plaintext, err := s.readHandshakeMessage() |
|
|
|
plaintext, err := s.readHandshakeMessage(hs) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error reading handshake message: %s", err) |
|
|
|
} |
|
|
|
|
|
|
|
// stage 1 //
|
|
|
|
err = s.sendHandshakeMessage(payload) |
|
|
|
err = s.sendHandshakeMessage(hs, payload) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error sending handshake message: %s", err) |
|
|
|
} |
|
|
|
|
|
|
|
// stage 2 //
|
|
|
|
plaintext, err = s.readHandshakeMessage() |
|
|
|
plaintext, err = s.readHandshakeMessage(hs) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("error reading handshake message: %s", err) |
|
|
|
} |
|
|
|
err = s.handleRemoteHandshakePayload(plaintext) |
|
|
|
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// we can discard the handshake state once the handshake completes
|
|
|
|
s.ns.hs = nil |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
@ -105,21 +100,20 @@ func (s *secureSession) runHandshake(ctx context.Context) error { |
|
|
|
// It sets the initial cipher states that will be used to protect traffic after the handshake.
|
|
|
|
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) { |
|
|
|
if s.initiator { |
|
|
|
s.ns.enc = cs1 |
|
|
|
s.ns.dec = cs2 |
|
|
|
s.enc = cs1 |
|
|
|
s.dec = cs2 |
|
|
|
} else { |
|
|
|
s.ns.enc = cs2 |
|
|
|
s.ns.dec = cs1 |
|
|
|
s.enc = cs2 |
|
|
|
s.dec = cs1 |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// sendHandshakeMessage sends the next handshake message in the sequence.
|
|
|
|
// Only safe to call from runHandshake, as it depends on handshake state.
|
|
|
|
// If payload is non-empty, it will be included in the handshake message.
|
|
|
|
// If this is the final message in the sequence, calls setCipherStates
|
|
|
|
// to initialize cipher states.
|
|
|
|
func (s *secureSession) sendHandshakeMessage(payload []byte) error { |
|
|
|
buf, cs1, cs2, err := s.ns.hs.WriteMessage(nil, payload) |
|
|
|
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error { |
|
|
|
buf, cs1, cs2, err := hs.WriteMessage(nil, payload) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
@ -137,16 +131,15 @@ func (s *secureSession) sendHandshakeMessage(payload []byte) error { |
|
|
|
|
|
|
|
// readHandshakeMessage reads a message from the insecure conn and tries to
|
|
|
|
// process it as the expected next message in the handshake sequence.
|
|
|
|
// Only safe to call from runHandshake, as it depends on handshake state.
|
|
|
|
// If the message contains a payload, it will be decrypted and returned.
|
|
|
|
// If this is the final message in the sequence, calls setCipherStates
|
|
|
|
// to initialize cipher states.
|
|
|
|
func (s *secureSession) readHandshakeMessage() ([]byte, error) { |
|
|
|
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) { |
|
|
|
raw, err := s.readMsgInsecure() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
msg, cs1, cs2, err := s.ns.hs.ReadMessage(nil, raw) |
|
|
|
msg, cs1, cs2, err := hs.ReadMessage(nil, raw) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
@ -158,8 +151,7 @@ func (s *secureSession) readHandshakeMessage() ([]byte, error) { |
|
|
|
|
|
|
|
// generateHandshakePayload creates a libp2p handshake payload with a
|
|
|
|
// signature of our static noise key.
|
|
|
|
// Must be called after the static key for the session has been generated.
|
|
|
|
func (s *secureSession) generateHandshakePayload() ([]byte, error) { |
|
|
|
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) { |
|
|
|
// setup libp2p keys
|
|
|
|
localKeyRaw, err := s.LocalPublicKey().Bytes() |
|
|
|
if err != nil { |
|
|
@ -167,7 +159,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) { |
|
|
|
} |
|
|
|
|
|
|
|
// sign noise data for payload
|
|
|
|
toSign := append([]byte(payloadSigPrefix), s.ns.localStatic.Public...) |
|
|
|
toSign := append([]byte(payloadSigPrefix), localStatic.Public...) |
|
|
|
signedPayload, err := s.localKey.Sign(toSign) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("error sigining handshake payload: %s", err) |
|
|
@ -186,8 +178,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) { |
|
|
|
|
|
|
|
// handleRemoteHandshakePayload unmarshals the handshake payload object sent
|
|
|
|
// by the remote peer and validates the signature against the peer's static Noise key.
|
|
|
|
// Only safe to call from runHandshake, as it depends on handshake state.
|
|
|
|
func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error { |
|
|
|
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error { |
|
|
|
// unmarshal payload
|
|
|
|
nhp := new(pb.NoiseHandshakePayload) |
|
|
|
err := proto.Unmarshal(payload, nhp) |
|
|
@ -212,7 +203,6 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error { |
|
|
|
|
|
|
|
// verify payload is signed by libp2p key
|
|
|
|
sig := nhp.GetIdentitySig() |
|
|
|
remoteStatic := s.ns.hs.PeerStatic() |
|
|
|
msg := append([]byte(payloadSigPrefix), remoteStatic...) |
|
|
|
ok, err := remotePubKey.Verify(msg, sig) |
|
|
|
if err != nil { |
|
|
|