mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2024-12-26 16:45:18 +00:00
206 lines
4.4 KiB
Go
206 lines
4.4 KiB
Go
package mocknet
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNotifications(t *testing.T) {
|
|
const swarmSize = 5
|
|
const timeout = 10 * time.Second
|
|
|
|
mn, err := FullMeshLinked(swarmSize)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer mn.Close()
|
|
|
|
// signup notifs
|
|
nets := mn.Nets()
|
|
notifiees := make(map[peer.ID]*netNotifiee, len(nets))
|
|
for _, pn := range nets {
|
|
defer pn.Close()
|
|
|
|
n := newNetNotifiee(t, swarmSize)
|
|
pn.Notify(n)
|
|
notifiees[pn.LocalPeer()] = n
|
|
}
|
|
|
|
// connect all but self
|
|
if err := mn.ConnectAllButSelf(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// test everyone got the correct connection opened calls
|
|
for _, s1 := range nets {
|
|
n := notifiees[s1.LocalPeer()]
|
|
notifs := make(map[peer.ID][]network.Conn)
|
|
for _, s2 := range nets {
|
|
if s2 == s1 {
|
|
continue
|
|
}
|
|
|
|
// this feels a little sketchy, but its probably okay
|
|
for len(s1.ConnsToPeer(s2.LocalPeer())) != len(notifs[s2.LocalPeer()]) {
|
|
select {
|
|
case c := <-n.connected:
|
|
nfp := notifs[c.RemotePeer()]
|
|
notifs[c.RemotePeer()] = append(nfp, c)
|
|
case <-time.After(timeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
}
|
|
|
|
for p, cons := range notifs {
|
|
expect := s1.ConnsToPeer(p)
|
|
if len(expect) != len(cons) {
|
|
t.Fatal("got different number of connections")
|
|
}
|
|
|
|
for _, c := range cons {
|
|
var found bool
|
|
for _, c2 := range expect {
|
|
if c == c2 {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Fatal("connection not found!")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
acceptedStream := make(chan struct{}, 1000)
|
|
for _, s := range nets {
|
|
s.SetStreamHandler(func(s network.Stream) {
|
|
acceptedStream <- struct{}{}
|
|
s.Close()
|
|
})
|
|
}
|
|
|
|
// Make sure we've received at last one stream per conn.
|
|
for _, s := range nets {
|
|
conns := s.Conns()
|
|
for _, c := range conns {
|
|
st1, err := c.NewStream(context.Background())
|
|
if err != nil {
|
|
t.Error(err)
|
|
continue
|
|
}
|
|
t.Logf("%s %s <--%p--> %s %s", c.LocalPeer(), c.LocalMultiaddr(), st1, c.RemotePeer(), c.RemoteMultiaddr())
|
|
st1.Close()
|
|
}
|
|
}
|
|
|
|
// close conns
|
|
for _, s1 := range nets {
|
|
n1 := notifiees[s1.LocalPeer()]
|
|
for _, c1 := range s1.Conns() {
|
|
c2 := ConnComplement(c1)
|
|
|
|
n2 := notifiees[c2.LocalPeer()]
|
|
c1.Close()
|
|
|
|
var c3, c4 network.Conn
|
|
select {
|
|
case c3 = <-n1.disconnected:
|
|
case <-time.After(timeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
if c1 != c3 {
|
|
t.Fatal("got incorrect conn", c1, c3)
|
|
}
|
|
|
|
select {
|
|
case c4 = <-n2.disconnected:
|
|
case <-time.After(timeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
if c2 != c4 {
|
|
t.Fatal("got incorrect conn", c1, c2)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, n1 := range notifiees {
|
|
// Avoid holding this lock while waiting, otherwise we can deadlock.
|
|
streamStateCopy := map[network.Stream]chan struct{}{}
|
|
n1.streamState.Lock()
|
|
for str, ch := range n1.streamState.m {
|
|
streamStateCopy[str] = ch
|
|
}
|
|
n1.streamState.Unlock()
|
|
|
|
for str1, ch1 := range streamStateCopy {
|
|
<-ch1
|
|
str2 := StreamComplement(str1)
|
|
n2 := notifiees[str1.Conn().RemotePeer()]
|
|
|
|
// make sure the OpenedStream notification was processed first
|
|
var ch2 chan struct{}
|
|
require.Eventually(t, func() bool {
|
|
n2.streamState.Lock()
|
|
defer n2.streamState.Unlock()
|
|
ch, ok := n2.streamState.m[str2]
|
|
if ok {
|
|
ch2 = ch
|
|
}
|
|
return ok
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
<-ch2
|
|
}
|
|
}
|
|
}
|
|
|
|
type netNotifiee struct {
|
|
t *testing.T
|
|
|
|
listen chan ma.Multiaddr
|
|
listenClose chan ma.Multiaddr
|
|
connected chan network.Conn
|
|
disconnected chan network.Conn
|
|
|
|
streamState struct {
|
|
sync.Mutex
|
|
m map[network.Stream]chan struct{}
|
|
}
|
|
}
|
|
|
|
func newNetNotifiee(t *testing.T, buffer int) *netNotifiee {
|
|
nn := &netNotifiee{
|
|
t: t,
|
|
listen: make(chan ma.Multiaddr, 1),
|
|
listenClose: make(chan ma.Multiaddr, 1),
|
|
connected: make(chan network.Conn, buffer*2),
|
|
disconnected: make(chan network.Conn, buffer*2),
|
|
}
|
|
nn.streamState.m = make(map[network.Stream]chan struct{})
|
|
return nn
|
|
}
|
|
|
|
func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {
|
|
nn.listen <- a
|
|
}
|
|
func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {
|
|
nn.listenClose <- a
|
|
}
|
|
func (nn *netNotifiee) Connected(n network.Network, v network.Conn) {
|
|
nn.connected <- v
|
|
}
|
|
func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
|
|
nn.disconnected <- v
|
|
}
|