|
|
@ -78,10 +78,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess |
|
|
|
done := make(chan struct{}) |
|
|
|
go func() { |
|
|
|
defer close(done) |
|
|
|
initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) |
|
|
|
initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID) |
|
|
|
}() |
|
|
|
|
|
|
|
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") |
|
|
|
respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "") |
|
|
|
<-done |
|
|
|
|
|
|
|
if initErr != nil { |
|
|
@ -171,7 +171,7 @@ func TestPeerIDMatch(t *testing.T) { |
|
|
|
done := make(chan struct{}) |
|
|
|
go func() { |
|
|
|
defer close(done) |
|
|
|
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) |
|
|
|
conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) |
|
|
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, conn.RemotePeer(), respTransport.localID) |
|
|
|
b := make([]byte, 6) |
|
|
@ -180,7 +180,7 @@ func TestPeerIDMatch(t *testing.T) { |
|
|
|
assert.Equal(t, b, []byte("foobar")) |
|
|
|
}() |
|
|
|
|
|
|
|
conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID) |
|
|
|
conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID) |
|
|
|
require.NoError(t, err) |
|
|
|
require.Equal(t, conn.RemotePeer(), initTransport.localID) |
|
|
|
_, err = conn.Write([]byte("foobar")) |
|
|
@ -194,11 +194,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { |
|
|
|
|
|
|
|
errChan := make(chan error) |
|
|
|
go func() { |
|
|
|
_, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") |
|
|
|
_, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id") |
|
|
|
errChan <- err |
|
|
|
}() |
|
|
|
|
|
|
|
_, err := respTransport.SecureInbound(context.TODO(), resp, "") |
|
|
|
_, err := respTransport.SecureInbound(context.Background(), resp, "") |
|
|
|
require.Error(t, err) |
|
|
|
|
|
|
|
initErr := <-errChan |
|
|
@ -214,13 +214,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { |
|
|
|
done := make(chan struct{}) |
|
|
|
go func() { |
|
|
|
defer close(done) |
|
|
|
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) |
|
|
|
conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) |
|
|
|
assert.NoError(t, err) |
|
|
|
_, err = conn.Read([]byte{0}) |
|
|
|
assert.Error(t, err) |
|
|
|
}() |
|
|
|
|
|
|
|
_, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id") |
|
|
|
_, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") |
|
|
|
require.Error(t, err, "expected responder to fail with peer ID mismatch error") |
|
|
|
<-done |
|
|
|
} |
|
|
@ -387,7 +387,7 @@ func TestPrologueMatches(t *testing.T) { |
|
|
|
tpt, err := initTransport. |
|
|
|
WithSessionOptions(Prologue(commonPrologue)) |
|
|
|
require.NoError(t, err) |
|
|
|
conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) |
|
|
|
conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) |
|
|
|
require.NoError(t, err) |
|
|
|
defer conn.Close() |
|
|
|
}() |
|
|
@ -395,7 +395,7 @@ func TestPrologueMatches(t *testing.T) { |
|
|
|
tpt, err := respTransport. |
|
|
|
WithSessionOptions(Prologue(commonPrologue)) |
|
|
|
require.NoError(t, err) |
|
|
|
conn, err := tpt.SecureInbound(context.TODO(), respConn, "") |
|
|
|
conn, err := tpt.SecureInbound(context.Background(), respConn, "") |
|
|
|
require.NoError(t, err) |
|
|
|
defer conn.Close() |
|
|
|
<-done |
|
|
@ -415,14 +415,125 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { |
|
|
|
tpt, err := initTransport. |
|
|
|
WithSessionOptions(Prologue(initPrologue)) |
|
|
|
require.NoError(t, err) |
|
|
|
_, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) |
|
|
|
_, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) |
|
|
|
require.Error(t, err) |
|
|
|
}() |
|
|
|
|
|
|
|
tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) |
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
_, err = tpt.SecureInbound(context.TODO(), respConn, "") |
|
|
|
_, err = tpt.SecureInbound(context.Background(), respConn, "") |
|
|
|
require.Error(t, err) |
|
|
|
<-done |
|
|
|
} |
|
|
|
|
|
|
|
type earlyDataHandler struct { |
|
|
|
send func(context.Context, net.Conn, peer.ID) []byte |
|
|
|
received func(context.Context, net.Conn, []byte) error |
|
|
|
} |
|
|
|
|
|
|
|
func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) []byte { |
|
|
|
if e.send == nil { |
|
|
|
return nil |
|
|
|
} |
|
|
|
return e.send(ctx, conn, id) |
|
|
|
} |
|
|
|
|
|
|
|
func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []byte) error { |
|
|
|
if e.received == nil { |
|
|
|
return nil |
|
|
|
} |
|
|
|
return e.received(ctx, conn, data) |
|
|
|
} |
|
|
|
|
|
|
|
func TestEarlyDataAccepted(t *testing.T) { |
|
|
|
var receivedEarlyData []byte |
|
|
|
serverEDH := &earlyDataHandler{ |
|
|
|
received: func(_ context.Context, _ net.Conn, data []byte) error { |
|
|
|
receivedEarlyData = data |
|
|
|
return nil |
|
|
|
}, |
|
|
|
} |
|
|
|
clientEDH := &earlyDataHandler{ |
|
|
|
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, |
|
|
|
} |
|
|
|
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) |
|
|
|
require.NoError(t, err) |
|
|
|
tpt := newTestTransport(t, crypto.Ed25519, 2048) |
|
|
|
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) |
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
initConn, respConn := newConnPair(t) |
|
|
|
|
|
|
|
errChan := make(chan error) |
|
|
|
go func() { |
|
|
|
_, err := respTransport.SecureInbound(context.Background(), initConn, "") |
|
|
|
errChan <- err |
|
|
|
}() |
|
|
|
|
|
|
|
conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) |
|
|
|
require.NoError(t, err) |
|
|
|
defer conn.Close() |
|
|
|
|
|
|
|
require.Equal(t, []byte("foobar"), receivedEarlyData) |
|
|
|
} |
|
|
|
|
|
|
|
func TestEarlyDataRejected(t *testing.T) { |
|
|
|
serverEDH := &earlyDataHandler{ |
|
|
|
received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") }, |
|
|
|
} |
|
|
|
clientEDH := &earlyDataHandler{ |
|
|
|
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, |
|
|
|
} |
|
|
|
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) |
|
|
|
require.NoError(t, err) |
|
|
|
tpt := newTestTransport(t, crypto.Ed25519, 2048) |
|
|
|
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) |
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
initConn, respConn := newConnPair(t) |
|
|
|
|
|
|
|
errChan := make(chan error) |
|
|
|
go func() { |
|
|
|
_, err := respTransport.SecureInbound(context.Background(), initConn, "") |
|
|
|
errChan <- err |
|
|
|
}() |
|
|
|
|
|
|
|
_, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) |
|
|
|
require.Error(t, err) |
|
|
|
|
|
|
|
select { |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
t.Fatal("timeout") |
|
|
|
case err := <-errChan: |
|
|
|
require.EqualError(t, err, "nope") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestEarlyDataRejectedWithNoHandler(t *testing.T) { |
|
|
|
clientEDH := &earlyDataHandler{ |
|
|
|
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, |
|
|
|
} |
|
|
|
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) |
|
|
|
require.NoError(t, err) |
|
|
|
respTransport := newTestTransport(t, crypto.Ed25519, 2048) |
|
|
|
|
|
|
|
initConn, respConn := newConnPair(t) |
|
|
|
|
|
|
|
errChan := make(chan error) |
|
|
|
go func() { |
|
|
|
_, err := respTransport.SecureInbound(context.Background(), initConn, "") |
|
|
|
errChan <- err |
|
|
|
}() |
|
|
|
|
|
|
|
_, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) |
|
|
|
require.Error(t, err) |
|
|
|
|
|
|
|
select { |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
t.Fatal("timeout") |
|
|
|
case err := <-errChan: |
|
|
|
require.Error(t, err) |
|
|
|
require.Contains(t, err.Error(), "received unexpected early data") |
|
|
|
} |
|
|
|
} |
|
|
|