|
|
@ -5,15 +5,16 @@ import ( |
|
|
|
"crypto/tls" |
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
pb "github.com/marten-seemann/go-libp2p-webtransport/pb" |
|
|
|
"github.com/multiformats/go-multihash" |
|
|
|
"net" |
|
|
|
"net/http" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/libp2p/go-libp2p-core/connmgr" |
|
|
|
"github.com/libp2p/go-libp2p-core/network" |
|
|
|
tpt "github.com/libp2p/go-libp2p-core/transport" |
|
|
|
|
|
|
|
noise "github.com/libp2p/go-libp2p-noise" |
|
|
|
"github.com/libp2p/go-libp2p/core/connmgr" |
|
|
|
"github.com/libp2p/go-libp2p/core/network" |
|
|
|
tpt "github.com/libp2p/go-libp2p/core/transport" |
|
|
|
"github.com/libp2p/go-libp2p/p2p/security/noise" |
|
|
|
|
|
|
|
"github.com/lucas-clemente/quic-go/http3" |
|
|
|
"github.com/marten-seemann/webtransport-go" |
|
|
@ -27,10 +28,11 @@ const queueLen = 16 |
|
|
|
const handshakeTimeout = 10 * time.Second |
|
|
|
|
|
|
|
type listener struct { |
|
|
|
transport tpt.Transport |
|
|
|
noise *noise.Transport |
|
|
|
certManager *certManager |
|
|
|
staticTLSConf *tls.Config |
|
|
|
transport tpt.Transport |
|
|
|
noise *noise.Transport |
|
|
|
certManager *certManager |
|
|
|
tlsConf *tls.Config |
|
|
|
isStaticTLSConf bool |
|
|
|
|
|
|
|
rcmgr network.ResourceManager |
|
|
|
gater connmgr.ConnectionGater |
|
|
@ -67,23 +69,25 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
isStaticTLSConf := tlsConf != nil |
|
|
|
if tlsConf == nil { |
|
|
|
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { |
|
|
|
return certManager.GetConfig(), nil |
|
|
|
}} |
|
|
|
} |
|
|
|
ln := &listener{ |
|
|
|
transport: transport, |
|
|
|
noise: noise, |
|
|
|
certManager: certManager, |
|
|
|
staticTLSConf: tlsConf, |
|
|
|
rcmgr: rcmgr, |
|
|
|
gater: gater, |
|
|
|
queue: make(chan tpt.CapableConn, queueLen), |
|
|
|
serverClosed: make(chan struct{}), |
|
|
|
addr: udpConn.LocalAddr(), |
|
|
|
multiaddr: localMultiaddr, |
|
|
|
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, |
|
|
|
transport: transport, |
|
|
|
noise: noise, |
|
|
|
certManager: certManager, |
|
|
|
tlsConf: tlsConf, |
|
|
|
isStaticTLSConf: isStaticTLSConf, |
|
|
|
rcmgr: rcmgr, |
|
|
|
gater: gater, |
|
|
|
queue: make(chan tpt.CapableConn, queueLen), |
|
|
|
serverClosed: make(chan struct{}), |
|
|
|
addr: udpConn.LocalAddr(), |
|
|
|
multiaddr: localMultiaddr, |
|
|
|
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, |
|
|
|
} |
|
|
|
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) |
|
|
|
mux := http.NewServeMux() |
|
|
@ -184,7 +188,11 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
c, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") |
|
|
|
n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) |
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to initialize Noise session: %w", err) |
|
|
|
} |
|
|
|
c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
@ -195,6 +203,31 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c |
|
|
|
}, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (l *listener) checkEarlyData(b []byte) error { |
|
|
|
var msg pb.WebTransport |
|
|
|
if err := msg.Unmarshal(b); err != nil { |
|
|
|
fmt.Println(1) |
|
|
|
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) |
|
|
|
} |
|
|
|
|
|
|
|
if l.isStaticTLSConf { |
|
|
|
if len(msg.CertHashes) > 0 { |
|
|
|
return errors.New("using static TLS config, didn't expect any certificate hashes") |
|
|
|
} |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) |
|
|
|
for _, h := range msg.CertHashes { |
|
|
|
dh, err := multihash.Decode(h) |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("failed to decode hash: %w", err) |
|
|
|
} |
|
|
|
hashes = append(hashes, *dh) |
|
|
|
} |
|
|
|
return l.certManager.Verify(hashes) |
|
|
|
} |
|
|
|
|
|
|
|
func (l *listener) Addr() net.Addr { |
|
|
|
return l.addr |
|
|
|
} |
|
|
|