mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2025-01-24 14:45:17 +00:00
249 lines
6.2 KiB
Go
249 lines
6.2 KiB
Go
|
package udpmux
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pion/stun"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func getSTUNBindingRequest(ufrag string) *stun.Message {
|
||
|
msg := stun.New()
|
||
|
msg.SetType(stun.BindingRequest)
|
||
|
uattr := stun.RawAttribute{
|
||
|
Type: stun.AttrUsername,
|
||
|
Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections
|
||
|
}
|
||
|
uattr.AddTo(msg)
|
||
|
msg.Encode()
|
||
|
return msg
|
||
|
}
|
||
|
|
||
|
func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) {
|
||
|
t.Helper()
|
||
|
msg := getSTUNBindingRequest(ufrag)
|
||
|
_, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0])
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
func newPacketConn(t *testing.T) net.PacketConn {
|
||
|
t.Helper()
|
||
|
udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||
|
c, err := net.ListenUDP("udp", udpPort0)
|
||
|
require.NoError(t, err)
|
||
|
t.Cleanup(func() { c.Close() })
|
||
|
return c
|
||
|
}
|
||
|
|
||
|
func TestAccept(t *testing.T) {
|
||
|
c := newPacketConn(t)
|
||
|
defer c.Close()
|
||
|
m := NewUDPMux(c)
|
||
|
m.Start()
|
||
|
defer m.Close()
|
||
|
|
||
|
ufrags := []string{"a", "b", "c", "d"}
|
||
|
conns := make([]net.PacketConn, len(ufrags))
|
||
|
for i, ufrag := range ufrags {
|
||
|
conns[i] = newPacketConn(t)
|
||
|
setupMapping(t, ufrag, conns[i], m)
|
||
|
}
|
||
|
for i, ufrag := range ufrags {
|
||
|
c, err := m.Accept(context.Background())
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, c.Ufrag, ufrag)
|
||
|
require.Equal(t, c.Addr, conns[i].LocalAddr())
|
||
|
}
|
||
|
|
||
|
for i, ufrag := range ufrags {
|
||
|
// should not be accepted
|
||
|
setupMapping(t, ufrag, conns[i], m)
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||
|
defer cancel()
|
||
|
_, err := m.Accept(ctx)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// should not be accepted
|
||
|
cc := newPacketConn(t)
|
||
|
setupMapping(t, ufrag, cc, m)
|
||
|
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||
|
defer cancel()
|
||
|
_, err = m.Accept(ctx)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetConn(t *testing.T) {
|
||
|
c := newPacketConn(t)
|
||
|
m := NewUDPMux(c)
|
||
|
m.Start()
|
||
|
defer m.Close()
|
||
|
|
||
|
ufrags := []string{"a", "b", "c", "d"}
|
||
|
conns := make([]net.PacketConn, len(ufrags))
|
||
|
for i, ufrag := range ufrags {
|
||
|
conns[i] = newPacketConn(t)
|
||
|
setupMapping(t, ufrag, conns[i], m)
|
||
|
}
|
||
|
for i, ufrag := range ufrags {
|
||
|
c, err := m.Accept(context.Background())
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, c.Ufrag, ufrag)
|
||
|
require.Equal(t, c.Addr, conns[i].LocalAddr())
|
||
|
}
|
||
|
|
||
|
for i, ufrag := range ufrags {
|
||
|
c, err := m.GetConn(ufrag, conns[i].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
msg := make([]byte, 100)
|
||
|
_, _, err = c.ReadFrom(msg)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
for i, ufrag := range ufrags {
|
||
|
cc := newPacketConn(t)
|
||
|
// setupMapping of cc to ufrags[0] and remove the stun binding request from the queue
|
||
|
setupMapping(t, ufrag, cc, m)
|
||
|
mc, err := m.GetConn(ufrag, cc.LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
msg := make([]byte, 100)
|
||
|
_, _, err = mc.ReadFrom(msg)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// Write from new connection should provide the new address on ReadFrom
|
||
|
_, err = cc.WriteTo([]byte("test1"), c.LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
n, addr, err := mc.ReadFrom(msg)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, addr, cc.LocalAddr())
|
||
|
require.Equal(t, "test1", string(msg[:n]))
|
||
|
|
||
|
// Write from original connection should provide the original address
|
||
|
_, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
n, addr, err = mc.ReadFrom(msg)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, addr, conns[i].LocalAddr())
|
||
|
require.Equal(t, "test2", string(msg[:n]))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRemoveConnByUfrag(t *testing.T) {
|
||
|
c := newPacketConn(t)
|
||
|
m := NewUDPMux(c)
|
||
|
m.Start()
|
||
|
defer m.Close()
|
||
|
|
||
|
// Map each ufrag to two addresses
|
||
|
ufrag := "a"
|
||
|
count := 10
|
||
|
conns := make([]net.PacketConn, count)
|
||
|
for i := 0; i < 10; i++ {
|
||
|
conns[i] = newPacketConn(t)
|
||
|
setupMapping(t, ufrag, conns[i], m)
|
||
|
}
|
||
|
mc, err := m.GetConn(ufrag, conns[0].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
for i := 0; i < 10; i++ {
|
||
|
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
if mc1 != mc {
|
||
|
t.Fatalf("expected the two muxed connections to be same")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Now remove the ufrag
|
||
|
m.RemoveConnByUfrag(ufrag)
|
||
|
|
||
|
// All connections should now be associated with b
|
||
|
ufrag = "b"
|
||
|
for i := 0; i < 10; i++ {
|
||
|
setupMapping(t, ufrag, conns[i], m)
|
||
|
}
|
||
|
mc, err = m.GetConn(ufrag, conns[0].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
for i := 0; i < 10; i++ {
|
||
|
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
if mc1 != mc {
|
||
|
t.Fatalf("expected the two muxed connections to be same")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Should be different even if the address is the same
|
||
|
mc1, err := m.GetConn("a", conns[0].LocalAddr())
|
||
|
require.NoError(t, err)
|
||
|
if mc1 == mc {
|
||
|
t.Fatalf("expected the two connections to be different")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMuxedConnection(t *testing.T) {
|
||
|
c := newPacketConn(t)
|
||
|
m := NewUDPMux(c)
|
||
|
m.Start()
|
||
|
defer m.Close()
|
||
|
|
||
|
msgCount := 3
|
||
|
connCount := 3
|
||
|
|
||
|
ufrags := []string{"a", "b", "c"}
|
||
|
addrUfragMap := make(map[string]string)
|
||
|
ufragConnsMap := make(map[string][]net.PacketConn)
|
||
|
for _, ufrag := range ufrags {
|
||
|
for i := 0; i < connCount; i++ {
|
||
|
cc := newPacketConn(t)
|
||
|
addrUfragMap[cc.LocalAddr().String()] = ufrag
|
||
|
ufragConnsMap[ufrag] = append(ufragConnsMap[ufrag], cc)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
done := make(chan bool, len(ufrags))
|
||
|
for _, ufrag := range ufrags {
|
||
|
go func(ufrag string) {
|
||
|
for _, cc := range ufragConnsMap[ufrag] {
|
||
|
setupMapping(t, ufrag, cc, m)
|
||
|
for j := 0; j < msgCount; j++ {
|
||
|
cc.WriteTo([]byte(ufrag), c.LocalAddr())
|
||
|
}
|
||
|
}
|
||
|
done <- true
|
||
|
}(ufrag)
|
||
|
}
|
||
|
for i := 0; i < len(ufrags); i++ {
|
||
|
<-done
|
||
|
}
|
||
|
|
||
|
for _, ufrag := range ufrags {
|
||
|
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant
|
||
|
require.NoError(t, err)
|
||
|
msgs := 0
|
||
|
stunRequests := 0
|
||
|
msg := make([]byte, 1500)
|
||
|
addrPacketCount := make(map[string]int)
|
||
|
for i := 0; i < connCount; i++ {
|
||
|
for j := 0; j < msgCount+1; j++ {
|
||
|
n, addr1, err := mc.ReadFrom(msg)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, addrUfragMap[addr1.String()], ufrag)
|
||
|
addrPacketCount[addr1.String()]++
|
||
|
if stun.IsMessage(msg[:n]) {
|
||
|
stunRequests++
|
||
|
} else {
|
||
|
msgs++
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
for addr, v := range addrPacketCount {
|
||
|
require.Equal(t, v, msgCount+1) // msgCount msgs + 1 STUN binding request
|
||
|
delete(addrUfragMap, addr)
|
||
|
}
|
||
|
require.Len(t, addrPacketCount, connCount)
|
||
|
}
|
||
|
require.Empty(t, addrUfragMap)
|
||
|
}
|