mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2024-12-26 08:35:17 +00:00
564 lines
16 KiB
Go
564 lines
16 KiB
Go
package libp2pwebrtc
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"errors"
|
|
"io"
|
|
"os"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
|
|
"github.com/libp2p/go-msgio/pbio"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
|
|
"github.com/pion/datachannel"
|
|
"github.com/pion/sctp"
|
|
"github.com/pion/webrtc/v3"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type detachedChan struct {
|
|
rwc datachannel.ReadWriteCloser
|
|
dc *webrtc.DataChannel
|
|
}
|
|
|
|
func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) {
|
|
s := webrtc.SettingEngine{}
|
|
s.SetIncludeLoopbackCandidate(true)
|
|
s.DetachDataChannels()
|
|
api := webrtc.NewAPI(webrtc.WithSettingEngine(s))
|
|
|
|
offerPC, err := api.NewPeerConnection(webrtc.Configuration{})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { offerPC.Close() })
|
|
offerRWCChan := make(chan detachedChan, 1)
|
|
offerDC, err := offerPC.CreateDataChannel("data", nil)
|
|
require.NoError(t, err)
|
|
offerDC.OnOpen(func() {
|
|
rwc, err := offerDC.Detach()
|
|
require.NoError(t, err)
|
|
offerRWCChan <- detachedChan{rwc: rwc, dc: offerDC}
|
|
})
|
|
|
|
answerPC, err := api.NewPeerConnection(webrtc.Configuration{})
|
|
require.NoError(t, err)
|
|
|
|
answerChan := make(chan detachedChan, 1)
|
|
answerPC.OnDataChannel(func(dc *webrtc.DataChannel) {
|
|
dc.OnOpen(func() {
|
|
rwc, err := dc.Detach()
|
|
require.NoError(t, err)
|
|
answerChan <- detachedChan{rwc: rwc, dc: dc}
|
|
})
|
|
})
|
|
t.Cleanup(func() { answerPC.Close() })
|
|
|
|
// Set ICE Candidate handlers. As soon as a PeerConnection has gathered a candidate send it to the other peer
|
|
answerPC.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
|
if candidate != nil {
|
|
require.NoError(t, offerPC.AddICECandidate(candidate.ToJSON()))
|
|
}
|
|
})
|
|
offerPC.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
|
if candidate != nil {
|
|
require.NoError(t, answerPC.AddICECandidate(candidate.ToJSON()))
|
|
}
|
|
})
|
|
|
|
// Set the handler for Peer connection state
|
|
// This will notify you when the peer has connected/disconnected
|
|
offerPC.OnConnectionStateChange(func(s webrtc.PeerConnectionState) {
|
|
if s == webrtc.PeerConnectionStateFailed {
|
|
t.Log("peer connection failed on offerer")
|
|
}
|
|
})
|
|
|
|
// Set the handler for Peer connection state
|
|
// This will notify you when the peer has connected/disconnected
|
|
answerPC.OnConnectionStateChange(func(s webrtc.PeerConnectionState) {
|
|
if s == webrtc.PeerConnectionStateFailed {
|
|
t.Log("peer connection failed on answerer")
|
|
}
|
|
})
|
|
|
|
// Now, create an offer
|
|
offer, err := offerPC.CreateOffer(nil)
|
|
require.NoError(t, err)
|
|
require.NoError(t, answerPC.SetRemoteDescription(offer))
|
|
require.NoError(t, offerPC.SetLocalDescription(offer))
|
|
|
|
answer, err := answerPC.CreateAnswer(nil)
|
|
require.NoError(t, err)
|
|
require.NoError(t, offerPC.SetRemoteDescription(answer))
|
|
require.NoError(t, answerPC.SetLocalDescription(answer))
|
|
|
|
return <-answerChan, <-offerRWCChan
|
|
}
|
|
|
|
// assertDataChannelOpen checks if the datachannel is open.
|
|
// It sends empty messages on the data channel to check if the channel is still open.
|
|
// The control message reader goroutine depends on exclusive access to datachannel.Read
|
|
// so we have to depend on Write to determine whether the channel has been closed.
|
|
func assertDataChannelOpen(t *testing.T, dc *datachannel.DataChannel) {
|
|
t.Helper()
|
|
emptyMsg := &pb.Message{}
|
|
msg, err := proto.Marshal(emptyMsg)
|
|
if err != nil {
|
|
t.Fatal("unexpected mashalling error", err)
|
|
}
|
|
for i := 0; i < 3; i++ {
|
|
_, err := dc.Write(msg)
|
|
if err != nil {
|
|
t.Fatal("unexpected write err: ", err)
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
// assertDataChannelClosed checks if the datachannel is closed.
|
|
// It sends empty messages on the data channel to check if the channel has been closed.
|
|
// The control message reader goroutine depends on exclusive access to datachannel.Read
|
|
// so we have to depend on Write to determine whether the channel has been closed.
|
|
func assertDataChannelClosed(t *testing.T, dc *datachannel.DataChannel) {
|
|
t.Helper()
|
|
emptyMsg := &pb.Message{}
|
|
msg, err := proto.Marshal(emptyMsg)
|
|
if err != nil {
|
|
t.Fatal("unexpected mashalling error", err)
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
_, err := dc.Write(msg)
|
|
if err != nil {
|
|
if errors.Is(err, sctp.ErrStreamClosed) {
|
|
return
|
|
} else {
|
|
t.Fatal("unexpected write err: ", err)
|
|
}
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
func TestStreamSimpleReadWriteClose(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
var clientDone, serverDone atomic.Bool
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) })
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) })
|
|
|
|
// send a foobar from the client
|
|
n, err := clientStr.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
require.Equal(t, 6, n)
|
|
require.NoError(t, clientStr.CloseWrite())
|
|
// writing after closing should error
|
|
_, err = clientStr.Write([]byte("foobar"))
|
|
require.Error(t, err)
|
|
require.False(t, clientDone.Load())
|
|
|
|
// now read all the data on the server side
|
|
b, err := io.ReadAll(serverStr)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte("foobar"), b)
|
|
// reading again should give another io.EOF
|
|
n, err = serverStr.Read(make([]byte, 10))
|
|
require.Zero(t, n)
|
|
require.ErrorIs(t, err, io.EOF)
|
|
require.False(t, serverDone.Load())
|
|
|
|
// send something back
|
|
_, err = serverStr.Write([]byte("lorem ipsum"))
|
|
require.NoError(t, err)
|
|
require.NoError(t, serverStr.CloseWrite())
|
|
|
|
// and read it at the client
|
|
require.False(t, clientDone.Load())
|
|
b, err = io.ReadAll(clientStr)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte("lorem ipsum"), b)
|
|
|
|
// stream is only cleaned up on calling Close or Reset
|
|
clientStr.Close()
|
|
serverStr.Close()
|
|
require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond)
|
|
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
|
|
require.NoError(t, serverStr.Close())
|
|
require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond)
|
|
}
|
|
|
|
func TestStreamPartialReads(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
_, err := serverStr.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
require.NoError(t, serverStr.CloseWrite())
|
|
|
|
n, err := clientStr.Read([]byte{}) // empty read
|
|
require.NoError(t, err)
|
|
require.Zero(t, n)
|
|
b := make([]byte, 3)
|
|
n, err = clientStr.Read(b)
|
|
require.Equal(t, 3, n)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte("foo"), b)
|
|
b, err = io.ReadAll(clientStr)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte("bar"), b)
|
|
}
|
|
|
|
func TestStreamSkipEmptyFrames(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
|
|
}
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Message: []byte("foo")}))
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
|
|
}
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Message: []byte("bar")}))
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
|
|
}
|
|
require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}))
|
|
|
|
var read []byte
|
|
var count int
|
|
for i := 0; i < 100; i++ {
|
|
b := make([]byte, 10)
|
|
count++
|
|
n, err := clientStr.Read(b)
|
|
read = append(read, b[:n]...)
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
require.NoError(t, err)
|
|
}
|
|
require.LessOrEqual(t, count, 3, "should've taken a maximum of 3 reads")
|
|
require.Equal(t, []byte("foobar"), read)
|
|
}
|
|
|
|
func TestStreamReadReturnsOnClose(t *testing.T) {
|
|
client, _ := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
_, err := clientStr.Read([]byte{0})
|
|
errChan <- err
|
|
}()
|
|
time.Sleep(100 * time.Millisecond) // give the Read call some time to hit the loop
|
|
require.NoError(t, clientStr.Close())
|
|
select {
|
|
case err := <-errChan:
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
_, err := clientStr.Read([]byte{0})
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
}
|
|
|
|
func TestStreamResets(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
var clientDone, serverDone atomic.Bool
|
|
clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) })
|
|
serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) })
|
|
|
|
// send a foobar from the client
|
|
_, err := clientStr.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
_, err = serverStr.Write([]byte("lorem ipsum"))
|
|
require.NoError(t, err)
|
|
require.NoError(t, clientStr.Reset()) // resetting resets both directions
|
|
require.True(t, clientDone.Load())
|
|
// attempting to write more data should result in a reset error
|
|
_, err = clientStr.Write([]byte("foobar"))
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
// read what the server sent
|
|
b, err := io.ReadAll(clientStr)
|
|
require.Empty(t, b)
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
|
|
// read the data on the server side
|
|
require.False(t, serverDone.Load())
|
|
b, err = io.ReadAll(serverStr)
|
|
require.Equal(t, []byte("foobar"), b)
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
require.Eventually(t, func() bool {
|
|
_, err := serverStr.Write([]byte("foobar"))
|
|
return errors.Is(err, network.ErrReset)
|
|
}, time.Second, 50*time.Millisecond)
|
|
serverStr.Close()
|
|
require.Eventually(t, func() bool {
|
|
return serverDone.Load()
|
|
}, time.Second, 50*time.Millisecond)
|
|
}
|
|
|
|
func TestStreamReadDeadlineAsync(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
timeout := 100 * time.Millisecond
|
|
if os.Getenv("CI") != "" {
|
|
timeout *= 5
|
|
}
|
|
start := time.Now()
|
|
clientStr.SetReadDeadline(start.Add(timeout))
|
|
_, err := clientStr.Read([]byte{0})
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
took := time.Since(start)
|
|
require.GreaterOrEqual(t, took, timeout)
|
|
require.LessOrEqual(t, took, timeout*3/2)
|
|
// repeated calls should return immediately
|
|
start = time.Now()
|
|
_, err = clientStr.Read([]byte{0})
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
require.LessOrEqual(t, time.Since(start), timeout/3)
|
|
// clear the deadline
|
|
clientStr.SetReadDeadline(time.Time{})
|
|
_, err = serverStr.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
_, err = clientStr.Read([]byte{0})
|
|
require.NoError(t, err)
|
|
require.LessOrEqual(t, time.Since(start), timeout/3)
|
|
}
|
|
|
|
func TestStreamWriteDeadlineAsync(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
_ = serverStr
|
|
|
|
b := make([]byte, 1024)
|
|
rand.Read(b)
|
|
start := time.Now()
|
|
timeout := 100 * time.Millisecond
|
|
if os.Getenv("CI") != "" {
|
|
timeout *= 5
|
|
}
|
|
clientStr.SetWriteDeadline(start.Add(timeout))
|
|
var hitDeadline bool
|
|
for i := 0; i < 2000; i++ {
|
|
if _, err := clientStr.Write(b); err != nil {
|
|
t.Logf("wrote %d kB", i)
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
hitDeadline = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, hitDeadline)
|
|
took := time.Since(start)
|
|
require.GreaterOrEqual(t, took, timeout)
|
|
require.LessOrEqual(t, took, timeout*3/2)
|
|
}
|
|
|
|
func TestStreamReadAfterClose(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
serverStr.Close()
|
|
b := make([]byte, 1)
|
|
_, err := clientStr.Read(b)
|
|
require.Equal(t, io.EOF, err)
|
|
_, err = clientStr.Read(nil)
|
|
require.Equal(t, io.EOF, err)
|
|
|
|
client, server = getDetachedDataChannels(t)
|
|
|
|
clientStr = newStream(client.dc, client.rwc, func() {})
|
|
serverStr = newStream(server.dc, server.rwc, func() {})
|
|
|
|
serverStr.Reset()
|
|
b = make([]byte, 1)
|
|
_, err = clientStr.Read(b)
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
_, err = clientStr.Read(nil)
|
|
require.ErrorIs(t, err, network.ErrReset)
|
|
}
|
|
|
|
func TestStreamCloseAfterFINACK(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
done := make(chan bool, 1)
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
go func() {
|
|
err := clientStr.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(200 * time.Millisecond):
|
|
t.Fatalf("Close should signal OnDone immediately")
|
|
}
|
|
|
|
// Reading FIN_ACK on server should trigger data channel close on the client
|
|
b := make([]byte, 1)
|
|
_, err := serverStr.Read(b)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, io.EOF)
|
|
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
|
|
}
|
|
|
|
// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half
|
|
// of the stream is closed.
|
|
func TestStreamFinAckAfterStopSending(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
done := make(chan bool, 1)
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
go func() {
|
|
clientStr.CloseRead()
|
|
clientStr.Write([]byte("hello world"))
|
|
done <- true
|
|
err := clientStr.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
<-done
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Errorf("Close should signal onDone immediately")
|
|
}
|
|
|
|
// serverStr has write half closed and read half open
|
|
// serverStr should still send FIN_ACK
|
|
b := make([]byte, 24)
|
|
_, err := serverStr.Read(b)
|
|
require.NoError(t, err)
|
|
serverStr.Close() // Sends stop_sending, fin
|
|
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
|
|
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
|
|
}
|
|
|
|
func TestStreamConcurrentClose(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
start := make(chan bool, 2)
|
|
done := make(chan bool, 2)
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
|
|
serverStr := newStream(server.dc, server.rwc, func() { done <- true })
|
|
|
|
go func() {
|
|
start <- true
|
|
clientStr.Close()
|
|
}()
|
|
go func() {
|
|
start <- true
|
|
serverStr.Close()
|
|
}()
|
|
<-start
|
|
<-start
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatalf("concurrent close should succeed quickly")
|
|
}
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatalf("concurrent close should succeed quickly")
|
|
}
|
|
|
|
// Wait for FIN_ACK AND datachannel close
|
|
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
|
|
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
|
|
|
|
}
|
|
|
|
func TestStreamResetAfterClose(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
done := make(chan bool, 2)
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
|
|
clientStr.Close()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatalf("Close should run cleanup immediately")
|
|
}
|
|
// The server data channel should still be open
|
|
assertDataChannelOpen(t, server.rwc.(*datachannel.DataChannel))
|
|
clientStr.Reset()
|
|
// Reset closes the datachannels
|
|
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
|
|
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
|
|
select {
|
|
case <-done:
|
|
t.Fatalf("onDone should not be called twice")
|
|
case <-time.After(50 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
done := make(chan bool, 1)
|
|
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
|
|
|
|
clientStr.Close()
|
|
|
|
select {
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatalf("Close should run cleanup immediately")
|
|
case <-done:
|
|
}
|
|
|
|
// sending FIN_ACK closes the datachannel
|
|
serverWriter := pbio.NewDelimitedWriter(server.rwc)
|
|
err := serverWriter.WriteMsg(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()})
|
|
require.NoError(t, err)
|
|
|
|
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
|
|
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
|
|
}
|
|
|
|
func TestStreamChunking(t *testing.T) {
|
|
client, server := getDetachedDataChannels(t)
|
|
|
|
clientStr := newStream(client.dc, client.rwc, func() {})
|
|
serverStr := newStream(server.dc, server.rwc, func() {})
|
|
|
|
const N = (16 << 10) + 1000
|
|
go func() {
|
|
data := make([]byte, N)
|
|
_, err := clientStr.Write(data)
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
data := make([]byte, N)
|
|
n, err := serverStr.Read(data)
|
|
require.NoError(t, err)
|
|
require.LessOrEqual(t, n, 16<<10)
|
|
|
|
nn, err := serverStr.Read(data)
|
|
require.NoError(t, err)
|
|
require.Equal(t, nn+n, N)
|
|
}
|