mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2025-01-24 06:36:13 +00:00
897 lines
25 KiB
Go
897 lines
25 KiB
Go
|
package basichost
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"reflect"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/libp2p/go-libp2p/core/event"
|
||
|
"github.com/libp2p/go-libp2p/core/host"
|
||
|
"github.com/libp2p/go-libp2p/core/network"
|
||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||
|
"github.com/libp2p/go-libp2p/core/peerstore"
|
||
|
"github.com/libp2p/go-libp2p/core/protocol"
|
||
|
"github.com/libp2p/go-libp2p/core/record"
|
||
|
"github.com/libp2p/go-libp2p/p2p/host/autonat"
|
||
|
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
|
||
|
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
|
||
|
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
|
||
|
|
||
|
ma "github.com/multiformats/go-multiaddr"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestHostDoubleClose(t *testing.T) {
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h1.Close()
|
||
|
h1.Close()
|
||
|
}
|
||
|
|
||
|
func TestHostSimple(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h1.Close()
|
||
|
h1.Start()
|
||
|
h2, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h2.Close()
|
||
|
h2.Start()
|
||
|
|
||
|
h2pi := h2.Peerstore().PeerInfo(h2.ID())
|
||
|
require.NoError(t, h1.Connect(ctx, h2pi))
|
||
|
|
||
|
piper, pipew := io.Pipe()
|
||
|
h2.SetStreamHandler(protocol.TestingID, func(s network.Stream) {
|
||
|
defer s.Close()
|
||
|
w := io.MultiWriter(s, pipew)
|
||
|
io.Copy(w, s) // mirror everything
|
||
|
})
|
||
|
|
||
|
s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// write to the stream
|
||
|
buf1 := []byte("abcdefghijkl")
|
||
|
_, err = s.Write(buf1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// get it from the stream (echoed)
|
||
|
buf2 := make([]byte, len(buf1))
|
||
|
_, err = io.ReadFull(s, buf2)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, buf1, buf2)
|
||
|
|
||
|
// get it from the pipe (tee)
|
||
|
buf3 := make([]byte, len(buf1))
|
||
|
_, err = io.ReadFull(piper, buf3)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, buf1, buf3)
|
||
|
}
|
||
|
|
||
|
func TestMultipleClose(t *testing.T) {
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.NoError(t, h.Close())
|
||
|
require.NoError(t, h.Close())
|
||
|
require.NoError(t, h.Close())
|
||
|
}
|
||
|
|
||
|
func TestSignedPeerRecordWithNoListenAddrs(t *testing.T) {
|
||
|
h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h.Close()
|
||
|
h.Start()
|
||
|
|
||
|
require.Empty(t, h.Addrs(), "expected no listen addrs")
|
||
|
// now add a listen addr
|
||
|
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")))
|
||
|
require.NotEmpty(t, h.Addrs(), "expected at least 1 listen addr")
|
||
|
|
||
|
cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore())
|
||
|
if !ok {
|
||
|
t.Fatalf("peerstore doesn't support certified addrs")
|
||
|
}
|
||
|
// the signed record with the new addr is added async
|
||
|
var env *record.Envelope
|
||
|
require.Eventually(t, func() bool {
|
||
|
env = cab.GetPeerRecord(h.ID())
|
||
|
return env != nil
|
||
|
}, 500*time.Millisecond, 10*time.Millisecond)
|
||
|
rec, err := env.Record()
|
||
|
require.NoError(t, err)
|
||
|
require.NotEmpty(t, rec.(*peer.PeerRecord).Addrs)
|
||
|
}
|
||
|
|
||
|
func TestProtocolHandlerEvents(t *testing.T) {
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h.Close()
|
||
|
|
||
|
sub, err := h.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}, eventbus.BufSize(16))
|
||
|
require.NoError(t, err)
|
||
|
defer sub.Close()
|
||
|
|
||
|
// the identify service adds new protocol handlers shortly after the host
|
||
|
// starts. this helps us filter those events out, since they're unrelated
|
||
|
// to the test.
|
||
|
isIdentify := func(evt event.EvtLocalProtocolsUpdated) bool {
|
||
|
for _, p := range evt.Added {
|
||
|
if p == identify.ID || p == identify.IDPush {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
nextEvent := func() event.EvtLocalProtocolsUpdated {
|
||
|
for {
|
||
|
select {
|
||
|
case evt := <-sub.Out():
|
||
|
next := evt.(event.EvtLocalProtocolsUpdated)
|
||
|
if isIdentify(next) {
|
||
|
continue
|
||
|
}
|
||
|
return next
|
||
|
case <-time.After(5 * time.Second):
|
||
|
t.Fatal("event not received in 5 seconds")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
assert := func(added, removed []protocol.ID) {
|
||
|
next := nextEvent()
|
||
|
if !reflect.DeepEqual(added, next.Added) {
|
||
|
t.Errorf("expected added: %v; received: %v", added, next.Added)
|
||
|
}
|
||
|
if !reflect.DeepEqual(removed, next.Removed) {
|
||
|
t.Errorf("expected removed: %v; received: %v", removed, next.Removed)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
h.SetStreamHandler(protocol.TestingID, func(s network.Stream) {})
|
||
|
assert([]protocol.ID{protocol.TestingID}, nil)
|
||
|
h.SetStreamHandler("foo", func(s network.Stream) {})
|
||
|
assert([]protocol.ID{"foo"}, nil)
|
||
|
h.RemoveStreamHandler(protocol.TestingID)
|
||
|
assert(nil, []protocol.ID{protocol.TestingID})
|
||
|
}
|
||
|
|
||
|
func TestHostAddrsFactory(t *testing.T) {
|
||
|
maddr := ma.StringCast("/ip4/1.2.3.4/tcp/1234")
|
||
|
addrsFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr {
|
||
|
return []ma.Multiaddr{maddr}
|
||
|
}
|
||
|
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{AddrsFactory: addrsFactory})
|
||
|
require.NoError(t, err)
|
||
|
defer h.Close()
|
||
|
|
||
|
addrs := h.Addrs()
|
||
|
if len(addrs) != 1 {
|
||
|
t.Fatalf("expected 1 addr, got %+v", addrs)
|
||
|
}
|
||
|
if !addrs[0].Equal(maddr) {
|
||
|
t.Fatalf("expected %s, got %s", maddr.String(), addrs[0].String())
|
||
|
}
|
||
|
|
||
|
autoNat, err := autonat.New(h, autonat.WithReachability(network.ReachabilityPublic))
|
||
|
if err != nil {
|
||
|
t.Fatalf("should be able to attach autonat: %v", err)
|
||
|
}
|
||
|
h.SetAutoNat(autoNat)
|
||
|
addrs = h.Addrs()
|
||
|
if len(addrs) != 1 {
|
||
|
t.Fatalf("didn't expect change in returned addresses.")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestLocalIPChangesWhenListenAddrChanges(t *testing.T) {
|
||
|
// no listen addrs
|
||
|
h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
|
||
|
require.NoError(t, err)
|
||
|
h.Start()
|
||
|
defer h.Close()
|
||
|
|
||
|
h.addrMu.Lock()
|
||
|
h.filteredInterfaceAddrs = nil
|
||
|
h.allInterfaceAddrs = nil
|
||
|
h.addrMu.Unlock()
|
||
|
|
||
|
// change listen addrs and verify local IP addr is not nil again
|
||
|
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")))
|
||
|
h.SignalAddressChange()
|
||
|
time.Sleep(1 * time.Second)
|
||
|
|
||
|
h.addrMu.RLock()
|
||
|
defer h.addrMu.RUnlock()
|
||
|
require.NotEmpty(t, h.filteredInterfaceAddrs)
|
||
|
require.NotEmpty(t, h.allInterfaceAddrs)
|
||
|
}
|
||
|
|
||
|
func TestAllAddrs(t *testing.T) {
|
||
|
// no listen addrs
|
||
|
h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h.Close()
|
||
|
require.Nil(t, h.AllAddrs())
|
||
|
|
||
|
// listen on loopback
|
||
|
laddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
|
||
|
require.NoError(t, h.Network().Listen(laddr))
|
||
|
require.Len(t, h.AllAddrs(), 1)
|
||
|
firstAddr := h.AllAddrs()[0]
|
||
|
require.Equal(t, "/ip4/127.0.0.1", ma.Split(firstAddr)[0].String())
|
||
|
|
||
|
// listen on IPv4 0.0.0.0
|
||
|
require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")))
|
||
|
// should contain localhost and private local addr along with previous listen address
|
||
|
require.Len(t, h.AllAddrs(), 3)
|
||
|
// Should still contain the original addr.
|
||
|
require.True(t, ma.Contains(h.AllAddrs(), firstAddr), "should still contain the original addr")
|
||
|
}
|
||
|
|
||
|
// getHostPair gets a new pair of hosts.
|
||
|
// The first host initiates the connection to the second host.
|
||
|
func getHostPair(t *testing.T) (host.Host, host.Host) {
|
||
|
t.Helper()
|
||
|
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h1.Start()
|
||
|
h2, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h2.Start()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
|
defer cancel()
|
||
|
h2pi := h2.Peerstore().PeerInfo(h2.ID())
|
||
|
require.NoError(t, h1.Connect(ctx, h2pi))
|
||
|
return h1, h2
|
||
|
}
|
||
|
|
||
|
func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
|
||
|
t.Helper()
|
||
|
select {
|
||
|
case proto := <-c:
|
||
|
if proto != exp {
|
||
|
t.Fatalf("should have connected on %s, got %s", exp, proto)
|
||
|
}
|
||
|
case <-time.After(time.Second * 5):
|
||
|
t.Fatal("timeout waiting for stream")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestHostProtoPreference(t *testing.T) {
|
||
|
h1, h2 := getHostPair(t)
|
||
|
defer h1.Close()
|
||
|
defer h2.Close()
|
||
|
|
||
|
const (
|
||
|
protoOld = "/testing"
|
||
|
protoNew = "/testing/1.1.0"
|
||
|
protoMinor = "/testing/1.2.0"
|
||
|
)
|
||
|
|
||
|
connectedOn := make(chan protocol.ID)
|
||
|
handler := func(s network.Stream) {
|
||
|
connectedOn <- s.Protocol()
|
||
|
s.Close()
|
||
|
}
|
||
|
|
||
|
// Prevent pushing identify information so this test works.
|
||
|
h1.RemoveStreamHandler(identify.IDPush)
|
||
|
|
||
|
h2.SetStreamHandler(protoOld, handler)
|
||
|
|
||
|
s, err := h1.NewStream(context.Background(), h2.ID(), protoMinor, protoNew, protoOld)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// force the lazy negotiation to complete
|
||
|
_, err = s.Write(nil)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
assertWait(t, connectedOn, protoOld)
|
||
|
s.Close()
|
||
|
|
||
|
h2.SetStreamHandlerMatch(protoMinor, func(protocol.ID) bool { return true }, handler)
|
||
|
// remembered preference will be chosen first, even when the other side newly supports it
|
||
|
s2, err := h1.NewStream(context.Background(), h2.ID(), protoMinor, protoNew, protoOld)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// required to force 'lazy' handshake
|
||
|
_, err = s2.Write([]byte("hello"))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
assertWait(t, connectedOn, protoOld)
|
||
|
s2.Close()
|
||
|
|
||
|
s3, err := h1.NewStream(context.Background(), h2.ID(), protoMinor)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// Force a lazy handshake as we may have received a protocol update by this point.
|
||
|
_, err = s3.Write([]byte("hello"))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
assertWait(t, connectedOn, protoMinor)
|
||
|
s3.Close()
|
||
|
}
|
||
|
|
||
|
func TestHostProtoMismatch(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
h1, h2 := getHostPair(t)
|
||
|
defer h1.Close()
|
||
|
defer h2.Close()
|
||
|
|
||
|
h1.SetStreamHandler("/super", func(s network.Stream) {
|
||
|
t.Error("shouldnt get here")
|
||
|
s.Reset()
|
||
|
})
|
||
|
|
||
|
_, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/baz/1.0.0")
|
||
|
if err == nil {
|
||
|
t.Fatal("expected new stream to fail")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestHostProtoPreknowledge(t *testing.T) {
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h1.Close()
|
||
|
|
||
|
h2, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h2.Close()
|
||
|
|
||
|
conn := make(chan protocol.ID)
|
||
|
handler := func(s network.Stream) {
|
||
|
conn <- s.Protocol()
|
||
|
s.Close()
|
||
|
}
|
||
|
|
||
|
h2.SetStreamHandler("/super", handler)
|
||
|
|
||
|
h1.Start()
|
||
|
h2.Start()
|
||
|
|
||
|
// Prevent pushing identify information so this test actually _uses_ the super protocol.
|
||
|
h1.RemoveStreamHandler(identify.IDPush)
|
||
|
|
||
|
h2pi := h2.Peerstore().PeerInfo(h2.ID())
|
||
|
// Filter to only 1 address so that we don't have to think about parallel
|
||
|
// connections in this test
|
||
|
h2pi.Addrs = h2pi.Addrs[:1]
|
||
|
require.NoError(t, h1.Connect(context.Background(), h2pi))
|
||
|
|
||
|
// This test implicitly relies on 1 connection. If a background identify
|
||
|
// completes after we set the stream handler below things break
|
||
|
require.Equal(t, 1, len(h1.Network().ConnsToPeer(h2.ID())))
|
||
|
|
||
|
// wait for identify handshake to finish completely
|
||
|
select {
|
||
|
case <-h1.ids.IdentifyWait(h1.Network().ConnsToPeer(h2.ID())[0]):
|
||
|
case <-time.After(time.Second * 5):
|
||
|
t.Fatal("timed out waiting for identify")
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-h2.ids.IdentifyWait(h2.Network().ConnsToPeer(h1.ID())[0]):
|
||
|
case <-time.After(time.Second * 5):
|
||
|
t.Fatal("timed out waiting for identify")
|
||
|
}
|
||
|
|
||
|
h2.SetStreamHandler("/foo", handler)
|
||
|
|
||
|
require.Never(t, func() bool {
|
||
|
protos, err := h1.Peerstore().GetProtocols(h2.ID())
|
||
|
require.NoError(t, err)
|
||
|
for _, p := range protos {
|
||
|
if p == "/foo" {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}, time.Second, 100*time.Millisecond)
|
||
|
|
||
|
s, err := h1.NewStream(context.Background(), h2.ID(), "/foo", "/bar", "/super")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
select {
|
||
|
case p := <-conn:
|
||
|
t.Fatal("shouldn't have gotten connection yet, we should have a lazy stream: ", p)
|
||
|
case <-time.After(time.Millisecond * 50):
|
||
|
}
|
||
|
|
||
|
_, err = s.Read(nil)
|
||
|
require.NoError(t, err)
|
||
|
assertWait(t, conn, "/super")
|
||
|
|
||
|
s.Close()
|
||
|
}
|
||
|
|
||
|
func TestNewDialOld(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
h1, h2 := getHostPair(t)
|
||
|
defer h1.Close()
|
||
|
defer h2.Close()
|
||
|
|
||
|
connectedOn := make(chan protocol.ID)
|
||
|
h2.SetStreamHandler("/testing", func(s network.Stream) {
|
||
|
connectedOn <- s.Protocol()
|
||
|
s.Close()
|
||
|
})
|
||
|
|
||
|
s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// force the lazy negotiation to complete
|
||
|
_, err = s.Write(nil)
|
||
|
require.NoError(t, err)
|
||
|
assertWait(t, connectedOn, "/testing")
|
||
|
|
||
|
require.Equal(t, s.Protocol(), protocol.ID("/testing"), "should have gotten /testing")
|
||
|
}
|
||
|
|
||
|
func TestNewStreamResolve(t *testing.T) {
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h1.Start()
|
||
|
h2, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h2.Start()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||
|
defer cancel()
|
||
|
|
||
|
// Get the tcp port that h2 is listening on.
|
||
|
h2pi := h2.Peerstore().PeerInfo(h2.ID())
|
||
|
var dialAddr string
|
||
|
const tcpPrefix = "/ip4/127.0.0.1/tcp/"
|
||
|
for _, addr := range h2pi.Addrs {
|
||
|
addrStr := addr.String()
|
||
|
if strings.HasPrefix(addrStr, tcpPrefix) {
|
||
|
port := addrStr[len(tcpPrefix):]
|
||
|
dialAddr = "/dns4/localhost/tcp/" + port
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
assert.NotEqual(t, dialAddr, "")
|
||
|
|
||
|
// Add the DNS multiaddr to h1's peerstore.
|
||
|
maddr, err := ma.NewMultiaddr(dialAddr)
|
||
|
require.NoError(t, err)
|
||
|
h1.Peerstore().AddAddr(h2.ID(), maddr, time.Second)
|
||
|
|
||
|
connectedOn := make(chan protocol.ID)
|
||
|
h2.SetStreamHandler("/testing", func(s network.Stream) {
|
||
|
connectedOn <- s.Protocol()
|
||
|
s.Close()
|
||
|
})
|
||
|
|
||
|
// NewStream will make a new connection using the DNS address in h1's
|
||
|
// peerstore.
|
||
|
s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// force the lazy negotiation to complete
|
||
|
_, err = s.Write(nil)
|
||
|
require.NoError(t, err)
|
||
|
assertWait(t, connectedOn, "/testing")
|
||
|
}
|
||
|
|
||
|
func TestProtoDowngrade(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
h1, h2 := getHostPair(t)
|
||
|
defer h1.Close()
|
||
|
defer h2.Close()
|
||
|
|
||
|
connectedOn := make(chan protocol.ID)
|
||
|
h2.SetStreamHandler("/testing/1.0.0", func(s network.Stream) {
|
||
|
defer s.Close()
|
||
|
result, err := io.ReadAll(s)
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, string(result), "bar")
|
||
|
connectedOn <- s.Protocol()
|
||
|
})
|
||
|
|
||
|
s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, s.Protocol(), protocol.ID("/testing/1.0.0"), "should have gotten /testing/1.0.0, got %s", s.Protocol())
|
||
|
|
||
|
_, err = s.Write([]byte("bar"))
|
||
|
require.NoError(t, err)
|
||
|
require.NoError(t, s.CloseWrite())
|
||
|
|
||
|
assertWait(t, connectedOn, "/testing/1.0.0")
|
||
|
require.NoError(t, s.Close())
|
||
|
|
||
|
h1.Network().ClosePeer(h2.ID())
|
||
|
h2.RemoveStreamHandler("/testing/1.0.0")
|
||
|
h2.SetStreamHandler("/testing", func(s network.Stream) {
|
||
|
defer s.Close()
|
||
|
result, err := io.ReadAll(s)
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, string(result), "foo")
|
||
|
connectedOn <- s.Protocol()
|
||
|
})
|
||
|
|
||
|
// Give us a second to update our protocol list. This happens async through the event bus.
|
||
|
// This is _almost_ instantaneous, but this test fails once every ~1k runs without this.
|
||
|
time.Sleep(time.Millisecond)
|
||
|
|
||
|
h2pi := h2.Peerstore().PeerInfo(h2.ID())
|
||
|
require.NoError(t, h1.Connect(ctx, h2pi))
|
||
|
|
||
|
s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, s2.Protocol(), protocol.ID("/testing"), "should have gotten /testing, got %s, %s", s.Protocol(), s.Conn())
|
||
|
|
||
|
_, err = s2.Write([]byte("foo"))
|
||
|
require.NoError(t, err)
|
||
|
require.NoError(t, s2.CloseWrite())
|
||
|
|
||
|
assertWait(t, connectedOn, "/testing")
|
||
|
}
|
||
|
|
||
|
func TestAddrChangeImmediatelyIfAddressNonEmpty(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
taddrs := []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/1234")}
|
||
|
|
||
|
starting := make(chan struct{})
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr {
|
||
|
<-starting
|
||
|
return taddrs
|
||
|
}})
|
||
|
require.NoError(t, err)
|
||
|
defer h.Close()
|
||
|
|
||
|
sub, err := h.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{})
|
||
|
close(starting)
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
defer sub.Close()
|
||
|
h.Start()
|
||
|
|
||
|
expected := event.EvtLocalAddressesUpdated{
|
||
|
Diffs: true,
|
||
|
Current: []event.UpdatedAddress{
|
||
|
{Action: event.Added, Address: ma.StringCast("/ip4/1.2.3.4/tcp/1234")},
|
||
|
},
|
||
|
Removed: []event.UpdatedAddress{}}
|
||
|
|
||
|
// assert we get expected event
|
||
|
evt := waitForAddrChangeEvent(ctx, sub, t)
|
||
|
if !updatedAddrEventsEqual(expected, evt) {
|
||
|
t.Errorf("change events not equal: \n\texpected: %v \n\tactual: %v", expected, evt)
|
||
|
}
|
||
|
|
||
|
// assert it's on the signed record
|
||
|
rc := peerRecordFromEnvelope(t, evt.SignedPeerRecord)
|
||
|
require.Equal(t, taddrs, rc.Addrs)
|
||
|
|
||
|
// assert it's in the peerstore
|
||
|
ev := h.Peerstore().(peerstore.CertifiedAddrBook).GetPeerRecord(h.ID())
|
||
|
require.NotNil(t, ev)
|
||
|
rc = peerRecordFromEnvelope(t, ev)
|
||
|
require.Equal(t, taddrs, rc.Addrs)
|
||
|
}
|
||
|
|
||
|
func TestStatefulAddrEvents(t *testing.T) {
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
h.Start()
|
||
|
defer h.Close()
|
||
|
|
||
|
sub, err := h.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{}, eventbus.BufSize(10))
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
defer sub.Close()
|
||
|
|
||
|
select {
|
||
|
case v := <-sub.Out():
|
||
|
assert.NotNil(t, v)
|
||
|
case <-time.After(time.Second * 5):
|
||
|
t.Error("timed out waiting for event")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestHostAddrChangeDetection(t *testing.T) {
|
||
|
// This test uses the address factory to provide several
|
||
|
// sets of listen addresses for the host. It advances through
|
||
|
// the sets by changing the currentAddrSet index var below.
|
||
|
addrSets := [][]ma.Multiaddr{
|
||
|
{},
|
||
|
{ma.StringCast("/ip4/1.2.3.4/tcp/1234")},
|
||
|
{ma.StringCast("/ip4/1.2.3.4/tcp/1234"), ma.StringCast("/ip4/2.3.4.5/tcp/1234")},
|
||
|
{ma.StringCast("/ip4/2.3.4.5/tcp/1234"), ma.StringCast("/ip4/3.4.5.6/tcp/4321")},
|
||
|
}
|
||
|
|
||
|
// The events we expect the host to emit when SignalAddressChange is called
|
||
|
// and the changes between addr sets are detected
|
||
|
expectedEvents := []event.EvtLocalAddressesUpdated{
|
||
|
{
|
||
|
Diffs: true,
|
||
|
Current: []event.UpdatedAddress{
|
||
|
{Action: event.Added, Address: ma.StringCast("/ip4/1.2.3.4/tcp/1234")},
|
||
|
},
|
||
|
Removed: []event.UpdatedAddress{},
|
||
|
},
|
||
|
{
|
||
|
Diffs: true,
|
||
|
Current: []event.UpdatedAddress{
|
||
|
{Action: event.Maintained, Address: ma.StringCast("/ip4/1.2.3.4/tcp/1234")},
|
||
|
{Action: event.Added, Address: ma.StringCast("/ip4/2.3.4.5/tcp/1234")},
|
||
|
},
|
||
|
Removed: []event.UpdatedAddress{},
|
||
|
},
|
||
|
{
|
||
|
Diffs: true,
|
||
|
Current: []event.UpdatedAddress{
|
||
|
{Action: event.Added, Address: ma.StringCast("/ip4/3.4.5.6/tcp/4321")},
|
||
|
{Action: event.Maintained, Address: ma.StringCast("/ip4/2.3.4.5/tcp/1234")},
|
||
|
},
|
||
|
Removed: []event.UpdatedAddress{
|
||
|
{Action: event.Removed, Address: ma.StringCast("/ip4/1.2.3.4/tcp/1234")},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
var lk sync.Mutex
|
||
|
currentAddrSet := 0
|
||
|
addrsFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr {
|
||
|
lk.Lock()
|
||
|
defer lk.Unlock()
|
||
|
return addrSets[currentAddrSet]
|
||
|
}
|
||
|
|
||
|
ctx := context.Background()
|
||
|
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{AddrsFactory: addrsFactory})
|
||
|
require.NoError(t, err)
|
||
|
h.Start()
|
||
|
defer h.Close()
|
||
|
|
||
|
sub, err := h.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{}, eventbus.BufSize(10))
|
||
|
require.NoError(t, err)
|
||
|
defer sub.Close()
|
||
|
|
||
|
// wait for the host background thread to start
|
||
|
time.Sleep(1 * time.Second)
|
||
|
// host should start with no addrs (addrSet 0)
|
||
|
addrs := h.Addrs()
|
||
|
if len(addrs) != 0 {
|
||
|
t.Fatalf("expected 0 addrs, got %d", len(addrs))
|
||
|
}
|
||
|
|
||
|
// change addr, signal and assert event
|
||
|
for i := 1; i < len(addrSets); i++ {
|
||
|
lk.Lock()
|
||
|
currentAddrSet = i
|
||
|
lk.Unlock()
|
||
|
h.SignalAddressChange()
|
||
|
evt := waitForAddrChangeEvent(ctx, sub, t)
|
||
|
if !updatedAddrEventsEqual(expectedEvents[i-1], evt) {
|
||
|
t.Errorf("change events not equal: \n\texpected: %v \n\tactual: %v", expectedEvents[i-1], evt)
|
||
|
}
|
||
|
|
||
|
// assert it's on the signed record
|
||
|
rc := peerRecordFromEnvelope(t, evt.SignedPeerRecord)
|
||
|
require.Equal(t, addrSets[i], rc.Addrs)
|
||
|
|
||
|
// assert it's in the peerstore
|
||
|
ev := h.Peerstore().(peerstore.CertifiedAddrBook).GetPeerRecord(h.ID())
|
||
|
require.NotNil(t, ev)
|
||
|
rc = peerRecordFromEnvelope(t, ev)
|
||
|
require.Equal(t, addrSets[i], rc.Addrs)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestNegotiationCancel(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
h1, h2 := getHostPair(t)
|
||
|
defer h1.Close()
|
||
|
defer h2.Close()
|
||
|
|
||
|
// pre-negotiation so we can make the negotiation hang.
|
||
|
h2.Network().SetStreamHandler(func(s network.Stream) {
|
||
|
<-ctx.Done() // wait till the test is done.
|
||
|
s.Reset()
|
||
|
})
|
||
|
|
||
|
ctx2, cancel2 := context.WithCancel(ctx)
|
||
|
defer cancel2()
|
||
|
|
||
|
errCh := make(chan error, 1)
|
||
|
go func() {
|
||
|
s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
|
||
|
if s != nil {
|
||
|
errCh <- fmt.Errorf("expected to fail negotiation")
|
||
|
return
|
||
|
}
|
||
|
errCh <- err
|
||
|
}()
|
||
|
select {
|
||
|
case err := <-errCh:
|
||
|
t.Fatal(err)
|
||
|
case <-time.After(10 * time.Millisecond):
|
||
|
// ok, hung.
|
||
|
}
|
||
|
cancel2()
|
||
|
|
||
|
select {
|
||
|
case err := <-errCh:
|
||
|
require.ErrorIs(t, err, context.Canceled)
|
||
|
case <-time.After(500 * time.Millisecond):
|
||
|
// failed to cancel
|
||
|
t.Fatal("expected negotiation to be canceled")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func waitForAddrChangeEvent(ctx context.Context, sub event.Subscription, t *testing.T) event.EvtLocalAddressesUpdated {
|
||
|
t.Helper()
|
||
|
for {
|
||
|
select {
|
||
|
case evt, more := <-sub.Out():
|
||
|
if !more {
|
||
|
t.Fatal("channel should not be closed")
|
||
|
}
|
||
|
return evt.(event.EvtLocalAddressesUpdated)
|
||
|
case <-ctx.Done():
|
||
|
t.Fatal("context should not have cancelled")
|
||
|
case <-time.After(5 * time.Second):
|
||
|
t.Fatal("timed out waiting for address change event")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// updatedAddrsEqual is a helper to check whether two lists of
|
||
|
// event.UpdatedAddress have the same contents, ignoring ordering.
|
||
|
func updatedAddrsEqual(a, b []event.UpdatedAddress) bool {
|
||
|
if len(a) != len(b) {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// We can't use an UpdatedAddress directly as a map key, since
|
||
|
// Multiaddr is an interface, and go won't know how to compare
|
||
|
// for equality. So we convert to this little struct, which
|
||
|
// stores the multiaddr as a string.
|
||
|
type ua struct {
|
||
|
action event.AddrAction
|
||
|
addrStr string
|
||
|
}
|
||
|
aSet := make(map[ua]struct{})
|
||
|
for _, addr := range a {
|
||
|
k := ua{action: addr.Action, addrStr: string(addr.Address.Bytes())}
|
||
|
aSet[k] = struct{}{}
|
||
|
}
|
||
|
for _, addr := range b {
|
||
|
k := ua{action: addr.Action, addrStr: string(addr.Address.Bytes())}
|
||
|
_, ok := aSet[k]
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// updatedAddrEventsEqual is a helper to check whether two
|
||
|
// event.EvtLocalAddressesUpdated are equal, ignoring the ordering of
|
||
|
// addresses in the inner lists.
|
||
|
func updatedAddrEventsEqual(a, b event.EvtLocalAddressesUpdated) bool {
|
||
|
return a.Diffs == b.Diffs &&
|
||
|
updatedAddrsEqual(a.Current, b.Current) &&
|
||
|
updatedAddrsEqual(a.Removed, b.Removed)
|
||
|
}
|
||
|
|
||
|
func peerRecordFromEnvelope(t *testing.T, ev *record.Envelope) *peer.PeerRecord {
|
||
|
t.Helper()
|
||
|
rec, err := ev.Record()
|
||
|
if err != nil {
|
||
|
t.Fatalf("error getting PeerRecord from event: %v", err)
|
||
|
return nil
|
||
|
}
|
||
|
peerRec, ok := rec.(*peer.PeerRecord)
|
||
|
if !ok {
|
||
|
t.Fatalf("wrong type for peer record")
|
||
|
return nil
|
||
|
}
|
||
|
return peerRec
|
||
|
}
|
||
|
|
||
|
func TestNormalizeMultiaddr(t *testing.T) {
|
||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||
|
require.NoError(t, err)
|
||
|
defer h1.Close()
|
||
|
|
||
|
require.Equal(t, "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", h1.NormalizeMultiaddr(ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")).String())
|
||
|
}
|
||
|
|
||
|
func TestInferWebtransportAddrsFromQuic(t *testing.T) {
|
||
|
type testCase struct {
|
||
|
name string
|
||
|
in []string
|
||
|
out []string
|
||
|
}
|
||
|
|
||
|
testCases := []testCase{
|
||
|
{
|
||
|
name: "Happy Path",
|
||
|
in: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1"},
|
||
|
out: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport"},
|
||
|
},
|
||
|
{
|
||
|
name: "Already discovered",
|
||
|
in: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport"},
|
||
|
out: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport"},
|
||
|
},
|
||
|
{
|
||
|
name: "Infer Many",
|
||
|
in: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1", "/ip4/4.3.2.1/udp/9999/quic-v1"},
|
||
|
out: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/9999/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1", "/ip4/4.3.2.1/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", "/ip4/4.3.2.1/udp/9999/quic-v1/webtransport"},
|
||
|
},
|
||
|
{
|
||
|
name: "No Common listeners",
|
||
|
in: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/1111/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1"},
|
||
|
out: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/0.0.0.0/udp/1111/quic-v1/webtransport", "/ip4/1.2.3.4/udp/9999/quic-v1"},
|
||
|
},
|
||
|
{
|
||
|
name: "No WebTransport",
|
||
|
in: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1"},
|
||
|
out: []string{"/ip4/0.0.0.0/udp/9999/quic-v1", "/ip4/1.2.3.4/udp/9999/quic-v1"},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
// Make sure the testCases are all valid multiaddrs
|
||
|
for _, tc := range testCases {
|
||
|
for _, addr := range tc.in {
|
||
|
_, err := ma.NewMultiaddr(addr)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
for _, addr := range tc.out {
|
||
|
_, err := ma.NewMultiaddr(addr)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
sort.StringSlice(tc.in).Sort()
|
||
|
sort.StringSlice(tc.out).Sort()
|
||
|
min := make([]ma.Multiaddr, 0, len(tc.in))
|
||
|
sort.Slice(tc.in, func(i, j int) bool {
|
||
|
return tc.in[i] < tc.in[j]
|
||
|
})
|
||
|
for _, addr := range tc.in {
|
||
|
min = append(min, ma.StringCast(addr))
|
||
|
}
|
||
|
outMa := inferWebtransportAddrsFromQuic(min)
|
||
|
outStr := make([]string, 0, len(outMa))
|
||
|
for _, addr := range outMa {
|
||
|
outStr = append(outStr, addr.String())
|
||
|
}
|
||
|
require.Equal(t, tc.out, outStr)
|
||
|
})
|
||
|
|
||
|
}
|
||
|
|
||
|
}
|