|
|
@ -1,9 +1,11 @@ |
|
|
|
package libp2pwebtransport |
|
|
|
|
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"context" |
|
|
|
"crypto/tls" |
|
|
|
"crypto/x509" |
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
|
"sync" |
|
|
@ -196,19 +198,27 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p |
|
|
|
|
|
|
|
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
|
|
|
|
// The server will verify that it advertised all of these certificate hashes.
|
|
|
|
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))} |
|
|
|
for _, certHash := range certHashes { |
|
|
|
h, err := multihash.Encode(certHash.Digest, certHash.Code) |
|
|
|
var verified bool |
|
|
|
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error { |
|
|
|
decodedCertHashes, err := decodeCertHashesFromProtobuf(b) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to encode certificate hash: %w", err) |
|
|
|
return err |
|
|
|
} |
|
|
|
msg.CertHashes = append(msg.CertHashes, h) |
|
|
|
for _, sent := range certHashes { |
|
|
|
var found bool |
|
|
|
for _, rcvd := range decodedCertHashes { |
|
|
|
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) { |
|
|
|
found = true |
|
|
|
break |
|
|
|
} |
|
|
|
} |
|
|
|
if !found { |
|
|
|
return fmt.Errorf("missing cert hash: %v", sent) |
|
|
|
} |
|
|
|
msgBytes, err := msg.Marshal() |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) |
|
|
|
} |
|
|
|
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil)) |
|
|
|
verified = true |
|
|
|
return nil |
|
|
|
}), nil)) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to create Noise transport: %w", err) |
|
|
|
} |
|
|
@ -216,12 +226,34 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
// The Noise handshake _should_ guarantee that our verification callback is called.
|
|
|
|
// Double-check just in case.
|
|
|
|
if !verified { |
|
|
|
return nil, errors.New("didn't verify") |
|
|
|
} |
|
|
|
return &connSecurityMultiaddrs{ |
|
|
|
ConnSecurity: c, |
|
|
|
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, |
|
|
|
}, nil |
|
|
|
} |
|
|
|
|
|
|
|
func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) { |
|
|
|
var msg pb.WebTransport |
|
|
|
if err := msg.Unmarshal(b); err != nil { |
|
|
|
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err) |
|
|
|
} |
|
|
|
|
|
|
|
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) |
|
|
|
for _, h := range msg.CertHashes { |
|
|
|
dh, err := multihash.Decode(h) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to decode hash: %w", err) |
|
|
|
} |
|
|
|
hashes = append(hashes, *dh) |
|
|
|
} |
|
|
|
return hashes, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (t *transport) CanDial(addr ma.Multiaddr) bool { |
|
|
|
var numHashes int |
|
|
|
ma.ForEach(addr, func(c ma.Component) bool { |
|
|
|