|
|
@ -5,15 +5,18 @@ import ( |
|
|
|
"errors" |
|
|
|
"io" |
|
|
|
"os" |
|
|
|
"sync/atomic" |
|
|
|
"testing" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" |
|
|
|
"github.com/libp2p/go-msgio/pbio" |
|
|
|
|
|
|
|
"github.com/libp2p/go-libp2p/core/network" |
|
|
|
|
|
|
|
"github.com/pion/datachannel" |
|
|
|
"github.com/pion/webrtc/v3" |
|
|
|
"github.com/stretchr/testify/assert" |
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
) |
|
|
|
|
|
|
@ -24,19 +27,20 @@ type detachedChan struct { |
|
|
|
|
|
|
|
func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) { |
|
|
|
s := webrtc.SettingEngine{} |
|
|
|
s.SetIncludeLoopbackCandidate(true) |
|
|
|
s.DetachDataChannels() |
|
|
|
api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) |
|
|
|
|
|
|
|
offerPC, err := api.NewPeerConnection(webrtc.Configuration{}) |
|
|
|
require.NoError(t, err) |
|
|
|
t.Cleanup(func() { offerPC.Close() }) |
|
|
|
offerRWCChan := make(chan datachannel.ReadWriteCloser, 1) |
|
|
|
offerRWCChan := make(chan detachedChan, 1) |
|
|
|
offerDC, err := offerPC.CreateDataChannel("data", nil) |
|
|
|
require.NoError(t, err) |
|
|
|
offerDC.OnOpen(func() { |
|
|
|
rwc, err := offerDC.Detach() |
|
|
|
require.NoError(t, err) |
|
|
|
offerRWCChan <- rwc |
|
|
|
offerRWCChan <- detachedChan{rwc: rwc, dc: offerDC} |
|
|
|
}) |
|
|
|
|
|
|
|
answerPC, err := api.NewPeerConnection(webrtc.Configuration{}) |
|
|
@ -91,15 +95,15 @@ func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) { |
|
|
|
require.NoError(t, offerPC.SetRemoteDescription(answer)) |
|
|
|
require.NoError(t, answerPC.SetLocalDescription(answer)) |
|
|
|
|
|
|
|
return <-answerChan, detachedChan{rwc: <-offerRWCChan, dc: offerDC} |
|
|
|
return <-answerChan, <-offerRWCChan |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamSimpleReadWriteClose(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
var clientDone, serverDone bool |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone = true }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone = true }) |
|
|
|
var clientDone, serverDone atomic.Bool |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) }) |
|
|
|
|
|
|
|
// send a foobar from the client
|
|
|
|
n, err := clientStr.Write([]byte("foobar")) |
|
|
@ -109,7 +113,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { |
|
|
|
// writing after closing should error
|
|
|
|
_, err = clientStr.Write([]byte("foobar")) |
|
|
|
require.Error(t, err) |
|
|
|
require.False(t, clientDone) |
|
|
|
require.False(t, clientDone.Load()) |
|
|
|
|
|
|
|
// now read all the data on the server side
|
|
|
|
b, err := io.ReadAll(serverStr) |
|
|
@ -119,19 +123,26 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { |
|
|
|
n, err = serverStr.Read(make([]byte, 10)) |
|
|
|
require.Zero(t, n) |
|
|
|
require.ErrorIs(t, err, io.EOF) |
|
|
|
require.False(t, serverDone) |
|
|
|
require.False(t, serverDone.Load()) |
|
|
|
|
|
|
|
// send something back
|
|
|
|
_, err = serverStr.Write([]byte("lorem ipsum")) |
|
|
|
require.NoError(t, err) |
|
|
|
require.NoError(t, serverStr.CloseWrite()) |
|
|
|
require.True(t, serverDone) |
|
|
|
|
|
|
|
// and read it at the client
|
|
|
|
require.False(t, clientDone) |
|
|
|
require.False(t, clientDone.Load()) |
|
|
|
b, err = io.ReadAll(clientStr) |
|
|
|
require.NoError(t, err) |
|
|
|
require.Equal(t, []byte("lorem ipsum"), b) |
|
|
|
require.True(t, clientDone) |
|
|
|
|
|
|
|
// stream is only cleaned up on calling Close or Reset
|
|
|
|
clientStr.Close() |
|
|
|
serverStr.Close() |
|
|
|
require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) |
|
|
|
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
|
|
|
|
require.NoError(t, serverStr.Close()) |
|
|
|
require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamPartialReads(t *testing.T) { |
|
|
@ -201,7 +212,7 @@ func TestStreamReadReturnsOnClose(t *testing.T) { |
|
|
|
_, err := clientStr.Read([]byte{0}) |
|
|
|
errChan <- err |
|
|
|
}() |
|
|
|
time.Sleep(50 * time.Millisecond) // give the Read call some time to hit the loop
|
|
|
|
time.Sleep(100 * time.Millisecond) // give the Read call some time to hit the loop
|
|
|
|
require.NoError(t, clientStr.Close()) |
|
|
|
select { |
|
|
|
case err := <-errChan: |
|
|
@ -209,14 +220,17 @@ func TestStreamReadReturnsOnClose(t *testing.T) { |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
t.Fatal("timeout") |
|
|
|
} |
|
|
|
|
|
|
|
_, err := clientStr.Read([]byte{0}) |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamResets(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
var clientDone, serverDone bool |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone = true }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone = true }) |
|
|
|
var clientDone, serverDone atomic.Bool |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) }) |
|
|
|
|
|
|
|
// send a foobar from the client
|
|
|
|
_, err := clientStr.Write([]byte("foobar")) |
|
|
@ -224,7 +238,7 @@ func TestStreamResets(t *testing.T) { |
|
|
|
_, err = serverStr.Write([]byte("lorem ipsum")) |
|
|
|
require.NoError(t, err) |
|
|
|
require.NoError(t, clientStr.Reset()) // resetting resets both directions
|
|
|
|
require.True(t, clientDone) |
|
|
|
require.True(t, clientDone.Load()) |
|
|
|
// attempting to write more data should result in a reset error
|
|
|
|
_, err = clientStr.Write([]byte("foobar")) |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
@ -234,7 +248,7 @@ func TestStreamResets(t *testing.T) { |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
|
|
|
|
|
// read the data on the server side
|
|
|
|
require.False(t, serverDone) |
|
|
|
require.False(t, serverDone.Load()) |
|
|
|
b, err = io.ReadAll(serverStr) |
|
|
|
require.Equal(t, []byte("foobar"), b) |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
@ -242,7 +256,10 @@ func TestStreamResets(t *testing.T) { |
|
|
|
_, err := serverStr.Write([]byte("foobar")) |
|
|
|
return errors.Is(err, network.ErrReset) |
|
|
|
}, time.Second, 50*time.Millisecond) |
|
|
|
require.True(t, serverDone) |
|
|
|
serverStr.Close() |
|
|
|
require.Eventually(t, func() bool { |
|
|
|
return serverDone.Load() |
|
|
|
}, time.Second, 50*time.Millisecond) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamReadDeadlineAsync(t *testing.T) { |
|
|
@ -305,3 +322,207 @@ func TestStreamWriteDeadlineAsync(t *testing.T) { |
|
|
|
require.GreaterOrEqual(t, took, timeout) |
|
|
|
require.LessOrEqual(t, took, timeout*3/2) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamReadAfterClose(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {}) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() {}) |
|
|
|
|
|
|
|
serverStr.Close() |
|
|
|
b := make([]byte, 1) |
|
|
|
_, err := clientStr.Read(b) |
|
|
|
require.Equal(t, io.EOF, err) |
|
|
|
_, err = clientStr.Read(nil) |
|
|
|
require.Equal(t, io.EOF, err) |
|
|
|
|
|
|
|
client, server = getDetachedDataChannels(t) |
|
|
|
|
|
|
|
clientStr = newStream(client.dc, client.rwc, func() {}) |
|
|
|
serverStr = newStream(server.dc, server.rwc, func() {}) |
|
|
|
|
|
|
|
serverStr.Reset() |
|
|
|
b = make([]byte, 1) |
|
|
|
_, err = clientStr.Read(b) |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
|
_, err = clientStr.Read(nil) |
|
|
|
require.ErrorIs(t, err, network.ErrReset) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamCloseAfterFINACK(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
done := make(chan bool, 1) |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() {}) |
|
|
|
|
|
|
|
go func() { |
|
|
|
done <- true |
|
|
|
err := clientStr.Close() |
|
|
|
assert.NoError(t, err) |
|
|
|
}() |
|
|
|
<-done |
|
|
|
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
t.Fatalf("Close should not have completed without processing FIN_ACK") |
|
|
|
case <-time.After(2 * time.Second): |
|
|
|
} |
|
|
|
|
|
|
|
b := make([]byte, 1) |
|
|
|
_, err := serverStr.Read(b) |
|
|
|
require.Error(t, err) |
|
|
|
require.ErrorIs(t, err, io.EOF) |
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(3 * time.Second): |
|
|
|
t.Errorf("Close should have completed") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half
|
|
|
|
// of the stream is closed.
|
|
|
|
func TestStreamFinAckAfterStopSending(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
done := make(chan bool, 1) |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() {}) |
|
|
|
|
|
|
|
go func() { |
|
|
|
clientStr.CloseRead() |
|
|
|
clientStr.Write([]byte("hello world")) |
|
|
|
done <- true |
|
|
|
err := clientStr.Close() |
|
|
|
assert.NoError(t, err) |
|
|
|
}() |
|
|
|
<-done |
|
|
|
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
t.Errorf("Close should not have completed without processing FIN_ACK") |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
} |
|
|
|
|
|
|
|
// serverStr has write half closed and read half open
|
|
|
|
// serverStr should still send FIN_ACK
|
|
|
|
b := make([]byte, 24) |
|
|
|
_, err := serverStr.Read(b) |
|
|
|
require.NoError(t, err) |
|
|
|
serverStr.Close() // Sends stop_sending, fin
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(5 * time.Second): |
|
|
|
t.Fatalf("Close should have completed") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamConcurrentClose(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
start := make(chan bool, 2) |
|
|
|
done := make(chan bool, 2) |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true }) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() { done <- true }) |
|
|
|
|
|
|
|
go func() { |
|
|
|
start <- true |
|
|
|
clientStr.Close() |
|
|
|
}() |
|
|
|
go func() { |
|
|
|
start <- true |
|
|
|
serverStr.Close() |
|
|
|
}() |
|
|
|
<-start |
|
|
|
<-start |
|
|
|
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(2 * time.Second): |
|
|
|
t.Fatalf("concurrent close should succeed quickly") |
|
|
|
} |
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(2 * time.Second): |
|
|
|
t.Fatalf("concurrent close should succeed quickly") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamResetAfterClose(t *testing.T) { |
|
|
|
client, _ := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
done := make(chan bool, 2) |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true }) |
|
|
|
clientStr.Close() |
|
|
|
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
t.Fatalf("Close shouldn't run cleanup immediately") |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
} |
|
|
|
|
|
|
|
clientStr.Reset() |
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(2 * time.Second): |
|
|
|
t.Fatalf("Reset should run callback immediately") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamDataChannelCloseOnFINACK(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
done := make(chan bool, 1) |
|
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true }) |
|
|
|
|
|
|
|
clientStr.Close() |
|
|
|
|
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
t.Fatalf("Close shouldn't run cleanup immediately") |
|
|
|
case <-time.After(500 * time.Millisecond): |
|
|
|
} |
|
|
|
|
|
|
|
serverWriter := pbio.NewDelimitedWriter(server.rwc) |
|
|
|
err := serverWriter.WriteMsg(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}) |
|
|
|
require.NoError(t, err) |
|
|
|
select { |
|
|
|
case <-done: |
|
|
|
case <-time.After(2 * time.Second): |
|
|
|
t.Fatalf("Callback should be run on reading FIN_ACK") |
|
|
|
} |
|
|
|
b := make([]byte, 100) |
|
|
|
N := 0 |
|
|
|
for { |
|
|
|
n, err := server.rwc.Read(b) |
|
|
|
N += n |
|
|
|
if err != nil { |
|
|
|
require.ErrorIs(t, err, io.EOF) |
|
|
|
break |
|
|
|
} |
|
|
|
} |
|
|
|
require.Less(t, N, 10) |
|
|
|
} |
|
|
|
|
|
|
|
func TestStreamChunking(t *testing.T) { |
|
|
|
client, server := getDetachedDataChannels(t) |
|
|
|
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {}) |
|
|
|
serverStr := newStream(server.dc, server.rwc, func() {}) |
|
|
|
|
|
|
|
const N = (16 << 10) + 1000 |
|
|
|
go func() { |
|
|
|
data := make([]byte, N) |
|
|
|
_, err := clientStr.Write(data) |
|
|
|
require.NoError(t, err) |
|
|
|
}() |
|
|
|
|
|
|
|
data := make([]byte, N) |
|
|
|
n, err := serverStr.Read(data) |
|
|
|
require.NoError(t, err) |
|
|
|
require.LessOrEqual(t, n, 16<<10) |
|
|
|
|
|
|
|
nn, err := serverStr.Read(data) |
|
|
|
require.NoError(t, err) |
|
|
|
require.Equal(t, nn+n, N) |
|
|
|
} |
|
|
|