mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2025-01-25 23:25:17 +00:00
784 lines
24 KiB
Go
784 lines
24 KiB
Go
|
package libp2pquic
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
mrand "math/rand"
|
||
|
"net"
|
||
|
"sync/atomic"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
ic "github.com/libp2p/go-libp2p/core/crypto"
|
||
|
"github.com/libp2p/go-libp2p/core/network"
|
||
|
mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks"
|
||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||
|
tpt "github.com/libp2p/go-libp2p/core/transport"
|
||
|
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
||
|
|
||
|
"github.com/golang/mock/gomock"
|
||
|
ma "github.com/multiformats/go-multiaddr"
|
||
|
"github.com/quic-go/quic-go"
|
||
|
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
//go:generate sh -c "go run github.com/golang/mock/mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && go run golang.org/x/tools/cmd/goimports -w mock_connection_gater_test.go"
|
||
|
|
||
|
type connTestCase struct {
|
||
|
Name string
|
||
|
Options []quicreuse.Option
|
||
|
}
|
||
|
|
||
|
var connTestCases = []*connTestCase{
|
||
|
{"reuseport_on", []quicreuse.Option{quicreuse.DisableDraft29()}},
|
||
|
{"reuseport_off", []quicreuse.Option{quicreuse.DisableReuseport(), quicreuse.DisableDraft29()}},
|
||
|
}
|
||
|
|
||
|
func createPeer(t *testing.T) (peer.ID, ic.PrivKey) {
|
||
|
var priv ic.PrivKey
|
||
|
var err error
|
||
|
switch mrand.Int() % 4 {
|
||
|
case 0:
|
||
|
priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader)
|
||
|
case 1:
|
||
|
priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader)
|
||
|
case 2:
|
||
|
priv, _, err = ic.GenerateEd25519Key(rand.Reader)
|
||
|
case 3:
|
||
|
priv, _, err = ic.GenerateSecp256k1Key(rand.Reader)
|
||
|
}
|
||
|
require.NoError(t, err)
|
||
|
id, err := peer.IDFromPrivateKey(priv)
|
||
|
require.NoError(t, err)
|
||
|
t.Logf("using a %s key: %s", priv.Type(), id.Pretty())
|
||
|
return id, priv
|
||
|
}
|
||
|
|
||
|
func runServer(t *testing.T, tr tpt.Transport, addr string) tpt.Listener {
|
||
|
t.Helper()
|
||
|
|
||
|
ln, err := tr.Listen(ma.StringCast(addr))
|
||
|
require.NoError(t, err)
|
||
|
return ln
|
||
|
}
|
||
|
|
||
|
func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManager {
|
||
|
t.Helper()
|
||
|
cm, err := quicreuse.NewConnManager([32]byte{}, opts...)
|
||
|
require.NoError(t, err)
|
||
|
t.Cleanup(func() { cm.Close() })
|
||
|
return cm
|
||
|
}
|
||
|
|
||
|
func TestHandshake(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testHandshake(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testHandshake(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
clientID, clientKey := createPeer(t)
|
||
|
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
|
||
|
handshake := func(t *testing.T, ln tpt.Listener) {
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
defer conn.Close()
|
||
|
serverConn, err := ln.Accept()
|
||
|
require.NoError(t, err)
|
||
|
defer serverConn.Close()
|
||
|
|
||
|
require.Equal(t, conn.LocalPeer(), clientID)
|
||
|
require.Equal(t, conn.RemotePeer(), serverID)
|
||
|
require.True(t, conn.RemotePublicKey().Equals(serverKey.GetPublic()), "remote public key doesn't match")
|
||
|
|
||
|
require.Equal(t, serverConn.LocalPeer(), serverID)
|
||
|
require.Equal(t, serverConn.RemotePeer(), clientID)
|
||
|
require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "remote public key doesn't match")
|
||
|
}
|
||
|
|
||
|
t.Run("on IPv4", func(t *testing.T) {
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln.Close()
|
||
|
handshake(t, ln)
|
||
|
})
|
||
|
|
||
|
t.Run("on IPv6", func(t *testing.T) {
|
||
|
ln := runServer(t, serverTransport, "/ip6/::1/udp/0/quic-v1")
|
||
|
defer ln.Close()
|
||
|
handshake(t, ln)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestResourceManagerSuccess(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testResourceManagerSuccess(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testResourceManagerSuccess(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
clientID, clientKey := createPeer(t)
|
||
|
|
||
|
ctrl := gomock.NewController(t)
|
||
|
defer ctrl.Finish()
|
||
|
|
||
|
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"))
|
||
|
require.NoError(t, err)
|
||
|
defer ln.Close()
|
||
|
|
||
|
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
|
||
|
connChan := make(chan tpt.CapableConn)
|
||
|
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||
|
go func() {
|
||
|
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Not(ln.Multiaddr())).Return(serverConnScope, nil)
|
||
|
serverConnScope.EXPECT().SetPeer(clientID)
|
||
|
serverConn, err := ln.Accept()
|
||
|
require.NoError(t, err)
|
||
|
connChan <- serverConn
|
||
|
}()
|
||
|
|
||
|
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||
|
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddr()).Return(connScope, nil)
|
||
|
connScope.EXPECT().SetPeer(serverID)
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
serverConn := <-connChan
|
||
|
t.Log("received conn")
|
||
|
connScope.EXPECT().Done().MinTimes(1) // for dialed connections, we might call Done multiple times
|
||
|
conn.Close()
|
||
|
serverConnScope.EXPECT().Done()
|
||
|
serverConn.Close()
|
||
|
}
|
||
|
|
||
|
func TestResourceManagerDialDenied(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testResourceManagerDialDenied(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testResourceManagerDialDenied(t *testing.T, tc *connTestCase) {
|
||
|
_, clientKey := createPeer(t)
|
||
|
ctrl := gomock.NewController(t)
|
||
|
defer ctrl.Finish()
|
||
|
|
||
|
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, rcmgr)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
|
||
|
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||
|
target := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1")
|
||
|
|
||
|
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, target).Return(connScope, nil)
|
||
|
rerr := errors.New("nope")
|
||
|
p := peer.ID("server")
|
||
|
connScope.EXPECT().SetPeer(p).Return(rerr)
|
||
|
connScope.EXPECT().Done()
|
||
|
|
||
|
_, err = clientTransport.Dial(context.Background(), target, p)
|
||
|
require.ErrorIs(t, err, rerr)
|
||
|
|
||
|
}
|
||
|
|
||
|
func TestResourceManagerAcceptDenied(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testResourceManagerAcceptDenied(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
clientID, clientKey := createPeer(t)
|
||
|
ctrl := gomock.NewController(t)
|
||
|
defer ctrl.Finish()
|
||
|
|
||
|
clientRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
|
||
|
serverRcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||
|
serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||
|
rerr := errors.New("denied")
|
||
|
gomock.InOrder(
|
||
|
serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(serverConnScope, nil),
|
||
|
serverConnScope.EXPECT().SetPeer(clientID).Return(rerr),
|
||
|
serverConnScope.EXPECT().Done(),
|
||
|
)
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"))
|
||
|
require.NoError(t, err)
|
||
|
defer ln.Close()
|
||
|
connChan := make(chan tpt.CapableConn)
|
||
|
go func() {
|
||
|
ln.Accept()
|
||
|
close(connChan)
|
||
|
}()
|
||
|
|
||
|
clientConnScope := mocknetwork.NewMockConnManagementScope(ctrl)
|
||
|
clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddr()).Return(clientConnScope, nil)
|
||
|
clientConnScope.EXPECT().SetPeer(serverID)
|
||
|
// In rare instances, the connection gating error will already occur on Dial.
|
||
|
// In that case, Done is called on the connection scope.
|
||
|
clientConnScope.EXPECT().Done().MaxTimes(1)
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
// In rare instances, the connection gating error will already occur on Dial.
|
||
|
if err == nil {
|
||
|
_, err = conn.AcceptStream()
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
select {
|
||
|
case <-connChan:
|
||
|
t.Fatal("didn't expect to accept a connection")
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreams(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testStreams(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testStreams(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln.Close()
|
||
|
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
defer conn.Close()
|
||
|
serverConn, err := ln.Accept()
|
||
|
require.NoError(t, err)
|
||
|
defer serverConn.Close()
|
||
|
|
||
|
str, err := conn.OpenStream(context.Background())
|
||
|
require.NoError(t, err)
|
||
|
_, err = str.Write([]byte("foobar"))
|
||
|
require.NoError(t, err)
|
||
|
str.Close()
|
||
|
sstr, err := serverConn.AcceptStream()
|
||
|
require.NoError(t, err)
|
||
|
data, err := io.ReadAll(sstr)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, data, []byte("foobar"))
|
||
|
}
|
||
|
|
||
|
func TestHandshakeFailPeerIDMismatch(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testHandshakeFailPeerIDMismatch(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testHandshakeFailPeerIDMismatch(t *testing.T, tc *connTestCase) {
|
||
|
_, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
thirdPartyID, _ := createPeer(t)
|
||
|
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
// dial, but expect the wrong peer ID
|
||
|
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
|
||
|
require.Error(t, err)
|
||
|
require.Contains(t, err.Error(), "CRYPTO_ERROR")
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
|
||
|
acceptErr := make(chan error)
|
||
|
go func() {
|
||
|
_, err := ln.Accept()
|
||
|
acceptErr <- err
|
||
|
}()
|
||
|
|
||
|
select {
|
||
|
case <-acceptErr:
|
||
|
t.Fatal("didn't expect Accept to return before being closed")
|
||
|
case <-time.After(100 * time.Millisecond):
|
||
|
}
|
||
|
|
||
|
require.NoError(t, ln.Close())
|
||
|
require.Error(t, <-acceptErr)
|
||
|
}
|
||
|
|
||
|
func TestConnectionGating(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testConnectionGating(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testConnectionGating(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
|
||
|
mockCtrl := gomock.NewController(t)
|
||
|
defer mockCtrl.Finish()
|
||
|
cg := NewMockConnectionGater(mockCtrl)
|
||
|
|
||
|
t.Run("accepted connections", func(t *testing.T) {
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, cg, nil)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
require.NoError(t, err)
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln.Close()
|
||
|
|
||
|
cg.EXPECT().InterceptAccept(gomock.Any())
|
||
|
|
||
|
accepted := make(chan struct{})
|
||
|
go func() {
|
||
|
defer close(accepted)
|
||
|
_, err := ln.Accept()
|
||
|
require.NoError(t, err)
|
||
|
}()
|
||
|
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
// make sure that connection attempts fails
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
// In rare instances, the connection gating error will already occur on Dial.
|
||
|
// In most cases, it will be returned by AcceptStream.
|
||
|
if err == nil {
|
||
|
_, err = conn.AcceptStream()
|
||
|
}
|
||
|
require.Contains(t, err.Error(), "connection gated")
|
||
|
|
||
|
// now allow the address and make sure the connection goes through
|
||
|
cg.EXPECT().InterceptAccept(gomock.Any()).Return(true)
|
||
|
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||
|
conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
defer conn.Close()
|
||
|
require.Eventually(t, func() bool {
|
||
|
select {
|
||
|
case <-accepted:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}, time.Second, 10*time.Millisecond)
|
||
|
})
|
||
|
|
||
|
t.Run("secured connections", func(t *testing.T) {
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln.Close()
|
||
|
|
||
|
cg := NewMockConnectionGater(mockCtrl)
|
||
|
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any())
|
||
|
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, cg, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
|
||
|
// make sure that connection attempts fails
|
||
|
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.Error(t, err)
|
||
|
require.Contains(t, err.Error(), "connection gated")
|
||
|
|
||
|
// now allow the peerId and make sure the connection goes through
|
||
|
cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
|
||
|
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
conn.Close()
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestDialTwo(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testDialTwo(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testDialTwo(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
serverID2, serverKey2 := createPeer(t)
|
||
|
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln1 := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln1.Close()
|
||
|
serverTransport2, err := NewTransport(serverKey2, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport2.(io.Closer).Close()
|
||
|
ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
defer ln2.Close()
|
||
|
|
||
|
data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB
|
||
|
// wait for both servers to accept a connection
|
||
|
// then send some data
|
||
|
go func() {
|
||
|
serverConn1, err := ln1.Accept()
|
||
|
require.NoError(t, err)
|
||
|
serverConn2, err := ln2.Accept()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
for _, c := range []tpt.CapableConn{serverConn1, serverConn2} {
|
||
|
go func(conn tpt.CapableConn) {
|
||
|
str, err := conn.OpenStream(context.Background())
|
||
|
require.NoError(t, err)
|
||
|
defer str.Close()
|
||
|
_, err = str.Write(data)
|
||
|
require.NoError(t, err)
|
||
|
}(c)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
|
||
|
require.NoError(t, err)
|
||
|
defer c1.Close()
|
||
|
c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2)
|
||
|
require.NoError(t, err)
|
||
|
defer c2.Close()
|
||
|
|
||
|
done := make(chan struct{}, 2)
|
||
|
// receive the data on both connections at the same time
|
||
|
for _, c := range []tpt.CapableConn{c1, c2} {
|
||
|
go func(conn tpt.CapableConn) {
|
||
|
str, err := conn.AcceptStream()
|
||
|
require.NoError(t, err)
|
||
|
str.CloseWrite()
|
||
|
d, err := io.ReadAll(str)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, d, data)
|
||
|
done <- struct{}{}
|
||
|
}(c)
|
||
|
}
|
||
|
|
||
|
for i := 0; i < 2; i++ {
|
||
|
require.Eventually(t, func() bool {
|
||
|
select {
|
||
|
case <-done:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}, 15*time.Second, 50*time.Millisecond)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStatelessReset(t *testing.T) {
|
||
|
for _, tc := range connTestCases {
|
||
|
t.Run(tc.Name, func(t *testing.T) {
|
||
|
testStatelessReset(t, tc)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testStatelessReset(t *testing.T, tc *connTestCase) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
|
||
|
serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer serverTransport.(io.Closer).Close()
|
||
|
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
|
||
|
var drop uint32
|
||
|
dropCallback := func(quicproxy.Direction, []byte) bool { return atomic.LoadUint32(&drop) > 0 }
|
||
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||
|
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||
|
DropPacket: dropCallback,
|
||
|
})
|
||
|
require.NoError(t, err)
|
||
|
proxyLocalAddr := proxy.LocalAddr()
|
||
|
|
||
|
// establish a connection
|
||
|
clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer clientTransport.(io.Closer).Close()
|
||
|
proxyAddr, err := quicreuse.ToQuicMultiaddr(proxy.LocalAddr(), quic.Version1)
|
||
|
require.NoError(t, err)
|
||
|
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
|
||
|
require.NoError(t, err)
|
||
|
connChan := make(chan tpt.CapableConn)
|
||
|
go func() {
|
||
|
conn, err := ln.Accept()
|
||
|
require.NoError(t, err)
|
||
|
str, err := conn.OpenStream(context.Background())
|
||
|
require.NoError(t, err)
|
||
|
_, err = conn.LocalMultiaddr().ValueForProtocol(ma.P_QUIC_V1)
|
||
|
require.NoError(t, err)
|
||
|
str.Write([]byte("foobar"))
|
||
|
connChan <- conn
|
||
|
}()
|
||
|
|
||
|
str, err := conn.AcceptStream()
|
||
|
require.NoError(t, err)
|
||
|
_, err = str.Read(make([]byte, 6))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// Stop forwarding packets and close the server.
|
||
|
// This prevents the CONNECTION_CLOSE from reaching the client.
|
||
|
atomic.StoreUint32(&drop, 1)
|
||
|
ln.Close()
|
||
|
(<-connChan).Close()
|
||
|
proxy.Close()
|
||
|
|
||
|
// Start another listener (on a different port).
|
||
|
ln, err = serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"))
|
||
|
require.NoError(t, err)
|
||
|
defer ln.Close()
|
||
|
// Now that the new server is up, re-enable packet forwarding.
|
||
|
atomic.StoreUint32(&drop, 0)
|
||
|
|
||
|
// Recreate the proxy, such that its client-facing port stays constant.
|
||
|
proxy, err = quicproxy.NewQuicProxy(proxyLocalAddr.String(), &quicproxy.Opts{
|
||
|
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||
|
DropPacket: dropCallback,
|
||
|
})
|
||
|
require.NoError(t, err)
|
||
|
defer proxy.Close()
|
||
|
|
||
|
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
|
||
|
// The new server doesn't have any state for the previously established connection.
|
||
|
// We expect it to send a stateless reset.
|
||
|
_, rerr := str.Write([]byte("Lorem ipsum dolor sit amet."))
|
||
|
if rerr == nil {
|
||
|
_, rerr = str.Read([]byte{0, 0})
|
||
|
}
|
||
|
require.Error(t, rerr)
|
||
|
var statelessResetErr *quic.StatelessResetError
|
||
|
require.ErrorAs(t, rerr, &statelessResetErr)
|
||
|
}
|
||
|
|
||
|
// Hole punching is only expected to work with reuseport enabled.
|
||
|
// We don't need to test `DisableReuseport` option.
|
||
|
func TestHolePunching(t *testing.T) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
clientID, clientKey := createPeer(t)
|
||
|
|
||
|
t1, err := NewTransport(serverKey, newConnManager(t), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer t1.(io.Closer).Close()
|
||
|
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
require.NoError(t, err)
|
||
|
ln1, err := t1.Listen(laddr)
|
||
|
require.NoError(t, err)
|
||
|
done1 := make(chan struct{})
|
||
|
go func() {
|
||
|
defer close(done1)
|
||
|
_, err := ln1.Accept()
|
||
|
require.Error(t, err, "didn't expect to accept any connections")
|
||
|
}()
|
||
|
|
||
|
t2, err := NewTransport(clientKey, newConnManager(t), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer t2.(io.Closer).Close()
|
||
|
ln2, err := t2.Listen(laddr)
|
||
|
require.NoError(t, err)
|
||
|
done2 := make(chan struct{})
|
||
|
go func() {
|
||
|
defer close(done2)
|
||
|
_, err := ln2.Accept()
|
||
|
require.Error(t, err, "didn't expect to accept any connections")
|
||
|
}()
|
||
|
connChan := make(chan tpt.CapableConn)
|
||
|
go func() {
|
||
|
conn, err := t2.Dial(
|
||
|
network.WithSimultaneousConnect(context.Background(), false, ""),
|
||
|
ln1.Multiaddr(),
|
||
|
serverID,
|
||
|
)
|
||
|
require.NoError(t, err)
|
||
|
connChan <- conn
|
||
|
}()
|
||
|
// Make sure the server role (the dial on t2) has progressed far enough.
|
||
|
// If it hasn't created the hole punch map entry, the connection will be accepted as a regular connection,
|
||
|
// which would make this test fail.
|
||
|
require.Eventually(t, func() bool {
|
||
|
tr := t2.(*transport)
|
||
|
tr.holePunchingMx.Lock()
|
||
|
defer tr.holePunchingMx.Unlock()
|
||
|
return len(tr.holePunching) > 0
|
||
|
}, time.Second, 10*time.Millisecond)
|
||
|
|
||
|
conn1, err := t1.Dial(
|
||
|
network.WithSimultaneousConnect(context.Background(), true, ""),
|
||
|
ln2.Multiaddr(),
|
||
|
clientID,
|
||
|
)
|
||
|
require.NoError(t, err)
|
||
|
defer conn1.Close()
|
||
|
require.Equal(t, conn1.RemotePeer(), clientID)
|
||
|
var conn2 tpt.CapableConn
|
||
|
require.Eventually(t, func() bool {
|
||
|
select {
|
||
|
case conn2 = <-connChan:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}, time.Second, 10*time.Millisecond)
|
||
|
defer conn2.Close()
|
||
|
require.Equal(t, conn2.RemotePeer(), serverID)
|
||
|
ln1.Close()
|
||
|
ln2.Close()
|
||
|
<-done1
|
||
|
<-done2
|
||
|
}
|
||
|
|
||
|
func TestGetErrorWhenListeningWithDraft29WhenDisabled(t *testing.T) {
|
||
|
_, serverKey := createPeer(t)
|
||
|
|
||
|
t1, err := NewTransport(serverKey, newConnManager(t, quicreuse.DisableDraft29()), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer t1.(io.Closer).Close()
|
||
|
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
|
||
|
require.NoError(t, err)
|
||
|
_, err = t1.Listen(laddr)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestClientCanDialDifferentQUICVersions(t *testing.T) {
|
||
|
type testCase struct {
|
||
|
name string
|
||
|
serverDisablesDraft29 bool
|
||
|
}
|
||
|
|
||
|
testCases := []testCase{
|
||
|
{
|
||
|
name: "Client dials quic-v1 on a quic-v1 only server",
|
||
|
serverDisablesDraft29: true,
|
||
|
},
|
||
|
{
|
||
|
name: "Client dials both draft 29 and v1 on server that supports both",
|
||
|
serverDisablesDraft29: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
serverID, serverKey := createPeer(t)
|
||
|
_, clientKey := createPeer(t)
|
||
|
|
||
|
var serverOpts []quicreuse.Option
|
||
|
if tc.serverDisablesDraft29 {
|
||
|
serverOpts = append(serverOpts, quicreuse.DisableDraft29())
|
||
|
}
|
||
|
|
||
|
t1, err := NewTransport(serverKey, newConnManager(t, serverOpts...), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer t1.(io.Closer).Close()
|
||
|
laddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")
|
||
|
ln1, err := t1.Listen(laddr)
|
||
|
require.NoError(t, err)
|
||
|
t.Cleanup(func() { ln1.Close() })
|
||
|
|
||
|
mas := []ma.Multiaddr{ln1.Multiaddr()}
|
||
|
var ln2 tpt.Listener
|
||
|
if !tc.serverDisablesDraft29 {
|
||
|
laddrDraft29 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic")
|
||
|
ln2, err = t1.Listen(laddrDraft29)
|
||
|
require.NoError(t, err)
|
||
|
t.Cleanup(func() { ln2.Close() })
|
||
|
mas = append(mas, ln2.Multiaddr())
|
||
|
}
|
||
|
|
||
|
t2, err := NewTransport(clientKey, newConnManager(t), nil, nil, nil)
|
||
|
require.NoError(t, err)
|
||
|
defer t2.(io.Closer).Close()
|
||
|
|
||
|
ctx := context.Background()
|
||
|
|
||
|
for _, a := range mas {
|
||
|
_, v, err := quicreuse.FromQuicMultiaddr(a)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
done := make(chan struct{})
|
||
|
go func() {
|
||
|
defer close(done)
|
||
|
var conn tpt.CapableConn
|
||
|
var err error
|
||
|
if v == quic.Version1 {
|
||
|
conn, err = ln1.Accept()
|
||
|
} else if v == quic.VersionDraft29 {
|
||
|
conn, err = ln2.Accept()
|
||
|
} else {
|
||
|
panic("unexpected version")
|
||
|
}
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, versionConnLocal, err := quicreuse.FromQuicMultiaddr(conn.LocalMultiaddr())
|
||
|
require.NoError(t, err)
|
||
|
_, versionConnRemote, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, v, versionConnLocal)
|
||
|
require.Equal(t, v, versionConnRemote)
|
||
|
}()
|
||
|
|
||
|
conn, err := t2.Dial(ctx, a, serverID)
|
||
|
require.NoError(t, err)
|
||
|
_, versionConnLocal, err := quicreuse.FromQuicMultiaddr(conn.LocalMultiaddr())
|
||
|
require.NoError(t, err)
|
||
|
_, versionConnRemote, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, v, versionConnLocal)
|
||
|
require.Equal(t, v, versionConnRemote)
|
||
|
|
||
|
<-done
|
||
|
conn.Close()
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|