package blossomsub

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"math/rand"
	"sort"
	"sync"
	"testing"
	"time"

	pb "source.quilibrium.com/quilibrium/monorepo/go-libp2p-blossomsub/pb"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"

	bhost "github.com/libp2p/go-libp2p/p2p/host/blank"
	swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"

	"github.com/libp2p/go-msgio/protoio"
)

func checkMessageRouting(t *testing.T, bitmask []byte, pubs []*PubSub, subs []*Subscription) {
	data := make([]byte, 16)
	rand.Read(data)

	for _, p := range pubs {
		err := p.Publish(bitmask, data)
		if err != nil {
			t.Fatal(err)
		}

		for _, s := range subs {
			assertReceive(t, s, data)
		}
	}
}

func getNetHosts(t *testing.T, ctx context.Context, n int) []host.Host {
	var out []host.Host

	for i := 0; i < n; i++ {
		netw := swarmt.GenSwarm(t)
		h := bhost.NewBlankHost(netw)
		t.Cleanup(func() { h.Close() })
		out = append(out, h)
	}

	return out
}

func connect(t *testing.T, a, b host.Host) {
	pinfo := a.Peerstore().PeerInfo(a.ID())
	err := b.Connect(context.Background(), pinfo)
	if err != nil {
		t.Fatal(err)
	}
}

func sparseConnect(t *testing.T, hosts []host.Host) {
	connectSome(t, hosts, 3)
}

func denseConnect(t *testing.T, hosts []host.Host) {
	connectSome(t, hosts, 10)
}

func connectSome(t *testing.T, hosts []host.Host, d int) {
	for i, a := range hosts {
		for j := 0; j < d; j++ {
			n := rand.Intn(len(hosts))
			if n == i {
				j--
				continue
			}

			b := hosts[n]

			connect(t, a, b)
		}
	}
}

func connectAll(t *testing.T, hosts []host.Host) {
	for i, a := range hosts {
		for j, b := range hosts {
			if i == j {
				continue
			}

			connect(t, a, b)
		}
	}
}

func getPubsub(ctx context.Context, h host.Host, opts ...Option) *PubSub {
	ps, err := NewFloodSub(ctx, h, opts...)
	if err != nil {
		panic(err)
	}
	return ps
}

func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
	var psubs []*PubSub
	for _, h := range hs {
		psubs = append(psubs, getPubsub(ctx, h, opts...))
	}
	return psubs
}

func getPubsubsWithOptionC(ctx context.Context, hs []host.Host, cons ...func(int) Option) []*PubSub {
	var psubs []*PubSub
	for _, h := range hs {
		var opts []Option
		for i, c := range cons {
			opts = append(opts, c(i))
		}
		psubs = append(psubs, getPubsub(ctx, h, opts...))
	}
	return psubs
}

func assertReceive(t *testing.T, ch *Subscription, exp []byte) {
	select {
	case msg := <-ch.ch:
		if !bytes.Equal(msg.GetData(), exp) {
			t.Fatalf("got wrong message, expected %s but got %s", string(exp), string(msg.GetData()))
		}
	case <-time.After(time.Second * 5):
		t.Logf("%#v\n", ch)
		t.Fatal("timed out waiting for message of: ", string(exp))
	}
}

func assertNeverReceives(t *testing.T, ch *Subscription, timeout time.Duration) {
	select {
	case msg := <-ch.ch:
		t.Logf("%#v\n", ch)
		t.Fatal("got unexpected message: ", string(msg.GetData()))
	case <-time.After(timeout):
	}
}

func TestBasicFloodsub(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hosts := getNetHosts(t, ctx, 20)

	psubs := getPubsubs(ctx, hosts)

	var msgs []*Subscription
	for _, ps := range psubs {
		subch, err := ps.Subscribe([]byte{0xf0, 0x0b, 0xa1, 0x20})
		if err != nil {
			t.Fatal(err)
		}

		msgs = append(msgs, subch)
	}

	// connectAll(t, hosts)
	sparseConnect(t, hosts)

	time.Sleep(time.Millisecond * 100)

	for i := 0; i < 100; i++ {
		msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i))

		owner := rand.Intn(len(psubs))

		psubs[owner].Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, msg)

		for _, sub := range msgs {
			got, err := sub.Next(ctx)
			if err != nil {
				t.Fatal(sub.err)
			}
			if !bytes.Equal(msg, got.Data) {
				t.Fatal("got wrong message!")
			}
		}
	}
}

func TestMultihops(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 6)

	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])
	connect(t, hosts[1], hosts[2])
	connect(t, hosts[2], hosts[3])
	connect(t, hosts[3], hosts[4])
	connect(t, hosts[4], hosts[5])

	var subs []*Subscription
	for i := 1; i < 6; i++ {
		ch, err := psubs[i].Subscribe([]byte{0xf0, 0x0b, 0xa1, 0x20})
		if err != nil {
			t.Fatal(err)
		}
		subs = append(subs, ch)
	}

	time.Sleep(time.Millisecond * 100)

	msg := []byte("i like cats")
	err := psubs[0].Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, msg)
	if err != nil {
		t.Fatal(err)
	}

	// last node in the chain should get the message
	select {
	case out := <-subs[4].ch:
		if !bytes.Equal(out.GetData(), msg) {
			t.Fatal("got wrong data")
		}
	case <-time.After(time.Second * 5):
		t.Fatal("timed out waiting for message")
	}
}

func TestReconnects(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 3)

	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])
	connect(t, hosts[0], hosts[2])

	A, err := psubs[1].Subscribe([]byte{0xca, 0x75})
	if err != nil {
		t.Fatal(err)
	}

	B, err := psubs[2].Subscribe([]byte{0xca, 0x75})
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 100)

	msg := []byte("apples and oranges")
	err = psubs[0].Publish([]byte{0xca, 0x75}, msg)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, A, msg)
	assertReceive(t, B, msg)

	B.Cancel()

	time.Sleep(time.Millisecond * 50)

	msg2 := []byte("potato")
	err = psubs[0].Publish([]byte{0xca, 0x75}, msg2)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, A, msg2)
	select {
	case _, ok := <-B.ch:
		if ok {
			t.Fatal("shouldnt have gotten data on this channel")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for B chan to be closed")
	}

	nSubs := len(psubs[2].mySubs[string([]byte{0xca, 0x75})])
	if nSubs > 0 {
		t.Fatal(`B should have 0 subscribers for channel []byte{0xca,0x75}, has`, nSubs)
	}

	ch2, err := psubs[2].Subscribe([]byte{0xca, 0x75})
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 100)

	nextmsg := []byte("ifps is kul")
	err = psubs[0].Publish([]byte{0xca, 0x75}, nextmsg)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, ch2, nextmsg)
}

// make sure messages arent routed between nodes who arent subscribed
func TestNoConnection(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 10)

	psubs := getPubsubs(ctx, hosts)

	ch, err := psubs[5].Subscribe([]byte{0xf0, 0x0b, 0xa1, 0x20})
	if err != nil {
		t.Fatal(err)
	}

	err = psubs[0].Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, []byte("TESTING"))
	if err != nil {
		t.Fatal(err)
	}

	select {
	case <-ch.ch:
		t.Fatal("shouldnt have gotten a message")
	case <-time.After(time.Millisecond * 200):
	}
}

func TestSelfReceive(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	host := getNetHosts(t, ctx, 1)[0]

	psub, err := NewFloodSub(ctx, host)
	if err != nil {
		t.Fatal(err)
	}

	msg := []byte("hello world")

	err = psub.Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, msg)
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 10)

	ch, err := psub.Subscribe([]byte{0xf0, 0x0b, 0xa1, 0x20})
	if err != nil {
		t.Fatal(err)
	}

	msg2 := []byte("goodbye world")
	err = psub.Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, msg2)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, ch, msg2)
}

func TestOneToOne(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])

	sub, err := psubs[1].Subscribe([]byte{0xf0, 0x0b, 0xa1, 0x20})
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 50)

	checkMessageRouting(t, []byte{0xf0, 0x0b, 0xa1, 0x20}, psubs, []*Subscription{sub})
}

func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
	peers := ps.ListPeers([]byte{})
	set := make(map[peer.ID]struct{})
	for _, p := range peers {
		set[p] = struct{}{}
	}

	for _, h := range has {
		if _, ok := set[hosts[h].ID()]; !ok {
			t.Fatal("expected to have connection to peer: ", h)
		}
	}
}

func TestTreeTopology(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 10)
	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])
	connect(t, hosts[1], hosts[2])
	connect(t, hosts[1], hosts[4])
	connect(t, hosts[2], hosts[3])
	connect(t, hosts[0], hosts[5])
	connect(t, hosts[5], hosts[6])
	connect(t, hosts[5], hosts[8])
	connect(t, hosts[6], hosts[7])
	connect(t, hosts[8], hosts[9])

	/*
		[0] -> [1] -> [2] -> [3]
		 |      L->[4]
		 v
		[5] -> [6] -> [7]
		 |
		 v
		[8] -> [9]
	*/

	var chs []*Subscription
	for _, ps := range psubs {
		ch, err := ps.Subscribe([]byte{0xf1, 0x22, 0xb0, 0x22})
		if err != nil {
			t.Fatal(err)
		}

		chs = append(chs, ch)
	}

	time.Sleep(time.Millisecond * 50)

	assertPeerLists(t, hosts, psubs[0], 1, 5)
	assertPeerLists(t, hosts, psubs[1], 0, 2, 4)
	assertPeerLists(t, hosts, psubs[2], 1, 3)

	checkMessageRouting(t, []byte{0xf1, 0x22, 0xb0, 0x22}, []*PubSub{psubs[9], psubs[3]}, chs)
}

func assertHasBitmasks(t *testing.T, ps *PubSub, expbitmasks ...string) {
	bitmasks := ps.GetBitmasks()
	sort.Strings(bitmasks)
	sort.Strings(expbitmasks)

	if len(bitmasks) != len(expbitmasks) {
		t.Fatalf("expected to have %v, but got %v", expbitmasks, bitmasks)
	}

	for i, v := range expbitmasks {
		if bitmasks[i] != v {
			t.Fatalf("expected %s but have %s", v, bitmasks[i])
		}
	}
}

func TestFloodSubPluggableProtocol(t *testing.T) {
	t.Run("multi-procol router acts like a hub", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		hosts := getNetHosts(t, ctx, 3)

		psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub", "/lsr/floodsub")
		psubB := mustCreatePubSub(ctx, t, hosts[1], "/esh/floodsub")
		psubC := mustCreatePubSub(ctx, t, hosts[2], "/lsr/floodsub")

		subA := mustSubscribe(t, psubA, []byte{0xf0, 0x0b, 0xa1, 0x20})
		defer subA.Cancel()

		subB := mustSubscribe(t, psubB, []byte{0xf0, 0x0b, 0xa1, 0x20})
		defer subB.Cancel()

		subC := mustSubscribe(t, psubC, []byte{0xf0, 0x0b, 0xa1, 0x20})
		defer subC.Cancel()

		// B --> A, C --> A
		connect(t, hosts[1], hosts[0])
		connect(t, hosts[2], hosts[0])

		time.Sleep(time.Millisecond * 100)

		psubC.Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, []byte([]byte{0xba, 0x12}))

		assertReceive(t, subA, []byte([]byte{0xba, 0x12}))
		assertReceive(t, subB, []byte([]byte{0xba, 0x12}))
		assertReceive(t, subC, []byte([]byte{0xba, 0x12}))
	})

	t.Run("won't talk to routers with no protocol overlap", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		hosts := getNetHosts(t, ctx, 2)

		psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub")
		psubB := mustCreatePubSub(ctx, t, hosts[1], "/lsr/floodsub")

		subA := mustSubscribe(t, psubA, []byte{0xf0, 0x0b, 0xa1, 0x20})
		defer subA.Cancel()

		subB := mustSubscribe(t, psubB, []byte{0xf0, 0x0b, 0xa1, 0x20})
		defer subB.Cancel()

		connect(t, hosts[1], hosts[0])

		time.Sleep(time.Millisecond * 100)

		psubA.Publish([]byte{0xf0, 0x0b, 0xa1, 0x20}, []byte([]byte{0xba, 0x12}))

		assertReceive(t, subA, []byte([]byte{0xba, 0x12}))

		pass := false
		select {
		case <-subB.ch:
			t.Fatal("different protocols: should not have received message")
		case <-time.After(time.Second * 1):
			pass = true
		}

		if !pass {
			t.Fatal("should have timed out waiting for message")
		}
	})
}

func mustCreatePubSub(ctx context.Context, t *testing.T, h host.Host, ps ...protocol.ID) *PubSub {
	psub, err := NewFloodsubWithProtocols(ctx, h, ps)
	if err != nil {
		t.Fatal(err)
	}

	return psub
}

func mustSubscribe(t *testing.T, ps *PubSub, bitmask []byte) *Subscription {
	sub, err := ps.Subscribe(bitmask)
	if err != nil {
		t.Fatal(err)
	}

	return sub
}

func TestSubReporting(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	host := getNetHosts(t, ctx, 1)[0]
	psub, err := NewFloodSub(ctx, host)
	if err != nil {
		t.Fatal(err)
	}

	fooSub, err := psub.Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}

	barSub, err := psub.Subscribe([]byte{0xba, 0x12})
	if err != nil {
		t.Fatal(err)
	}

	assertHasBitmasks(t, psub, string([]byte{0xf0, 0x00}), string([]byte{0xba, 0x12}))

	_, err = psub.Subscribe([]byte{0xba, 0x20})
	if err != nil {
		t.Fatal(err)
	}

	assertHasBitmasks(t, psub, string([]byte{0xf0, 0x00}), string([]byte{0xba, 0x12}), string([]byte{0xba, 0x20}))

	barSub.Cancel()
	assertHasBitmasks(t, psub, string([]byte{0xf0, 0x00}), string([]byte{0xba, 0x20}))
	fooSub.Cancel()
	assertHasBitmasks(t, psub, string([]byte{0xba, 0x20}))

	_, err = psub.Subscribe([]byte{0xf1, 0x24})
	if err != nil {
		t.Fatal(err)
	}

	assertHasBitmasks(t, psub, string([]byte{0xba, 0x20}), string([]byte{0xf1, 0x24}))
}

func TestPeerBitmaskReporting(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 4)
	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])
	connect(t, hosts[0], hosts[2])
	connect(t, hosts[0], hosts[3])

	_, err := psubs[1].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}
	_, err = psubs[1].Subscribe([]byte{0xba, 0x12})
	if err != nil {
		t.Fatal(err)
	}
	_, err = psubs[1].Subscribe([]byte{0xba, 0x20})
	if err != nil {
		t.Fatal(err)
	}

	_, err = psubs[2].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}
	_, err = psubs[2].Subscribe([]byte{0x01, 0xdf})
	if err != nil {
		t.Fatal(err)
	}

	_, err = psubs[3].Subscribe([]byte{0xba, 0x20})
	if err != nil {
		t.Fatal(err)
	}
	_, err = psubs[3].Subscribe([]byte{0x01, 0xdf})
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 200)

	peers := psubs[0].ListPeers([]byte{0x01, 0xdf})
	assertPeerList(t, peers, hosts[2].ID(), hosts[3].ID())

	peers = psubs[0].ListPeers([]byte{0xf0, 0x00})
	assertPeerList(t, peers, hosts[1].ID(), hosts[2].ID())

	peers = psubs[0].ListPeers([]byte{0xba, 0x20})
	assertPeerList(t, peers, hosts[1].ID(), hosts[3].ID())

	peers = psubs[0].ListPeers([]byte{0xba, 0x12})
	assertPeerList(t, peers, hosts[1].ID())
}

func TestSubscribeMultipleTimes(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])

	sub1, err := psubs[0].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}
	sub2, err := psubs[0].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}

	// make sure subscribing is finished by the time we publish
	time.Sleep(10 * time.Millisecond)

	psubs[1].Publish([]byte{0xf0, 0x00}, []byte([]byte{0xba, 0x12}))

	msg, err := sub1.Next(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v.", err)
	}

	data := msg.GetData()

	if !bytes.Equal(data, []byte{0xba, 0x12}) {
		t.Fatalf("data is %s, expected %s.", data, []byte{0xba, 0x12})
	}

	msg, err = sub2.Next(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v.", err)
	}
	data = msg.GetData()

	if !bytes.Equal(data, []byte{0xba, 0x12}) {
		t.Fatalf("data is %s, expected %s.", data, []byte{0xba, 0x12})
	}
}

func TestPeerDisconnect(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts)

	connect(t, hosts[0], hosts[1])

	_, err := psubs[0].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}

	_, err = psubs[1].Subscribe([]byte{0xf0, 0x00})
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 300)

	peers := psubs[0].ListPeers([]byte{0xf0, 0x00})
	assertPeerList(t, peers, hosts[1].ID())
	for _, c := range hosts[1].Network().ConnsToPeer(hosts[0].ID()) {
		c.Close()
	}

	time.Sleep(time.Millisecond * 300)

	peers = psubs[0].ListPeers([]byte{0xf0, 0x00})
	assertPeerList(t, peers)
}

func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) {
	sort.Sort(peer.IDSlice(peers))
	sort.Sort(peer.IDSlice(expected))

	if len(peers) != len(expected) {
		t.Fatalf("mismatch: %s != %s", peers, expected)
	}

	for i, p := range peers {
		if expected[i] != p {
			t.Fatalf("mismatch: %s != %s", peers, expected)
		}
	}
}

func TestWithNoSigning(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts, WithNoAuthor(), WithMessageIdFn(func(pmsg *pb.Message) string {
		// silly content-based test message-ID: just use the data as whole
		return base64.URLEncoding.EncodeToString(pmsg.Data)
	}))

	connect(t, hosts[0], hosts[1])

	bitmask := []byte{0xf0, 0x0b, 0xa1, 0x20}
	data := []byte("this is a message")

	sub, err := psubs[1].Subscribe(bitmask)
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 10)

	err = psubs[0].Publish(bitmask, data)
	if err != nil {
		t.Fatal(err)
	}

	msg, err := sub.Next(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if msg.Signature != nil {
		t.Fatal("signature in message")
	}
	if msg.From != nil {
		t.Fatal("from in message")
	}
	if msg.Seqno != nil {
		t.Fatal("seqno in message")
	}
	if string(msg.Data) != string(data) {
		t.Fatalf("unexpected data: %s", string(msg.Data))
	}
}

func TestWithSigning(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts, WithStrictSignatureVerification(true))

	connect(t, hosts[0], hosts[1])

	bitmask := []byte{0xf0, 0x0b, 0xa1, 0x20}
	data := []byte("this is a message")

	sub, err := psubs[1].Subscribe(bitmask)
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 10)

	err = psubs[0].Publish(bitmask, data)
	if err != nil {
		t.Fatal(err)
	}

	msg, err := sub.Next(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if msg.Signature == nil {
		t.Fatal("no signature in message")
	}
	if msg.From == nil {
		t.Fatal("from not in message")
	}
	if msg.Seqno == nil {
		t.Fatal("seqno not in message")
	}
	if string(msg.Data) != string(data) {
		t.Fatalf("unexpected data: %s", string(msg.Data))
	}
}

func TestImproperlySignedMessageRejected(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	adversary := hosts[0]
	honestPeer := hosts[1]

	// The adversary enables signing, but disables verification to let through
	// an incorrectly signed message.
	adversaryPubSub := getPubsub(
		ctx,
		adversary,
		WithMessageSigning(true),
		WithStrictSignatureVerification(false),
	)
	honestPubSub := getPubsub(
		ctx,
		honestPeer,
		WithStrictSignatureVerification(true),
	)

	connect(t, adversary, honestPeer)

	var (
		bitmask          = []byte{0xf0, 0x0b, 0xa1, 0x20}
		correctMessage   = []byte("this is a correct message")
		incorrectMessage = []byte("this is the incorrect message")
	)

	adversarySubscription, err := adversaryPubSub.Subscribe(bitmask)
	if err != nil {
		t.Fatal(err)
	}
	honestPeerSubscription, err := honestPubSub.Subscribe(bitmask)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(time.Millisecond * 50)

	// First the adversary sends the correct message.
	err = adversaryPubSub.Publish(bitmask, correctMessage)
	if err != nil {
		t.Fatal(err)
	}

	// Change the sign key for the adversarial peer, and send the second,
	// incorrectly signed, message.
	adversaryPubSub.signID = honestPubSub.signID
	adversaryPubSub.signKey = honestPubSub.host.Peerstore().PrivKey(honestPubSub.signID)
	err = adversaryPubSub.Publish(bitmask, incorrectMessage)
	if err != nil {
		t.Fatal(err)
	}

	var adversaryMessages []*Message
	adversaryContext, adversaryCancel := context.WithCancel(ctx)
	go func(ctx context.Context) {
		for {
			select {
			case <-ctx.Done():
				return
			default:
				msg, err := adversarySubscription.Next(ctx)
				if err != nil {
					return
				}
				adversaryMessages = append(adversaryMessages, msg)
			}
		}
	}(adversaryContext)

	<-time.After(1 * time.Second)
	adversaryCancel()

	// Ensure the adversary successfully publishes the incorrectly signed
	// message. If the adversary "sees" this, we successfully got through
	// their local validation.
	if len(adversaryMessages) != 2 {
		t.Fatalf("got %d messages, expected 2", len(adversaryMessages))
	}

	// the honest peer's validation process will drop the message;
	// next will never furnish the incorrect message.
	var honestPeerMessages []*Message
	honestPeerContext, honestPeerCancel := context.WithCancel(ctx)
	go func(ctx context.Context) {
		for {
			select {
			case <-ctx.Done():
				return
			default:
				msg, err := honestPeerSubscription.Next(ctx)
				if err != nil {
					return
				}
				honestPeerMessages = append(honestPeerMessages, msg)
			}
		}
	}(honestPeerContext)

	<-time.After(1 * time.Second)
	honestPeerCancel()

	if len(honestPeerMessages) != 1 {
		t.Fatalf("got %d messages, expected 1", len(honestPeerMessages))
	}
	if string(honestPeerMessages[0].GetData()) != string(correctMessage) {
		t.Fatalf(
			"got %s, expected message %s",
			honestPeerMessages[0].GetData(),
			correctMessage,
		)
	}
}

func TestMessageSender(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	bitmask := []byte{0xf0, 0x0b, 0xa1, 0x20}

	hosts := getNetHosts(t, ctx, 3)
	psubs := getPubsubs(ctx, hosts)

	var msgs []*Subscription
	for _, ps := range psubs {
		subch, err := ps.Subscribe(bitmask)
		if err != nil {
			t.Fatal(err)
		}

		msgs = append(msgs, subch)
	}

	connect(t, hosts[0], hosts[1])
	connect(t, hosts[1], hosts[2])

	time.Sleep(time.Millisecond * 100)

	for i := 0; i < 3; i++ {
		for j := 0; j < 100; j++ {
			msg := []byte(fmt.Sprintf("%d sent %d", i, j))

			psubs[i].Publish(bitmask, msg)

			for k, sub := range msgs {
				got, err := sub.Next(ctx)
				if err != nil {
					t.Fatal(sub.err)
				}
				if !bytes.Equal(msg, got.Data) {
					t.Fatal("got wrong message!")
				}

				var expectedHost int
				if i == k {
					expectedHost = i
				} else if k != 1 {
					expectedHost = 1
				} else {
					expectedHost = i
				}

				if got.ReceivedFrom != hosts[expectedHost].ID() {
					t.Fatal("got wrong message sender")
				}
			}
		}
	}
}

func TestConfigurableMaxMessageSize(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 10)

	// use a 4mb limit; default is 1mb; we'll test with a 2mb payload.
	psubs := getPubsubs(ctx, hosts, WithMaxMessageSize(1<<22))

	sparseConnect(t, hosts)
	time.Sleep(time.Millisecond * 100)

	bitmask := []byte{0xf0, 0x0b, 0xa1, 0x20}
	var subs []*Subscription
	for _, ps := range psubs {
		subch, err := ps.Subscribe(bitmask)
		if err != nil {
			t.Fatal(err)
		}
		subs = append(subs, subch)
	}

	// 2mb payload.
	msg := make([]byte, 1<<21)
	rand.Read(msg)
	err := psubs[0].Publish(bitmask, msg)
	if err != nil {
		t.Fatal(err)
	}

	// make sure that all peers received the message.
	for _, sub := range subs {
		got, err := sub.Next(ctx)
		if err != nil {
			t.Fatal(sub.err)
		}
		if !bytes.Equal(msg, got.Data) {
			t.Fatal("got wrong message!")
		}
	}

}

func TestAnnounceRetry(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	ps := getPubsub(ctx, hosts[0])
	watcher := &announceWatcher{}
	hosts[1].SetStreamHandler(FloodSubID, watcher.handleStream)

	_, err := ps.Subscribe([]byte{0x7e, 0x57})
	if err != nil {
		t.Fatal(err)
	}

	// connect the watcher to the pubsub
	connect(t, hosts[0], hosts[1])

	// wait a bit for the first subscription to be emitted and trigger announce retry
	time.Sleep(100 * time.Millisecond)
	go ps.announceRetry(hosts[1].ID(), []byte{0x7e, 0x57}, true)

	// wait a bit for the subscription to propagate and ensure it was received twice
	time.Sleep(time.Second + 100*time.Millisecond)
	count := watcher.countSubs()
	if count != 2 {
		t.Fatalf("expected 2 subscription messages, but got %d", count)
	}
}

type announceWatcher struct {
	mx   sync.Mutex
	subs int
}

func (aw *announceWatcher) handleStream(s network.Stream) {
	defer s.Close()

	r := protoio.NewDelimitedReader(s, 1<<20)

	var rpc pb.RPC
	for {
		rpc.Reset()
		err := r.ReadMsg(&rpc)
		if err != nil {
			if err != io.EOF {
				s.Reset()
			}
			return
		}

		for _, sub := range rpc.GetSubscriptions() {
			if sub.GetSubscribe() && bytes.Equal(sub.GetBitmask(), []byte{0x7e, 0x57}) {
				aw.mx.Lock()
				aw.subs++
				aw.mx.Unlock()
			}
		}
	}
}

func (aw *announceWatcher) countSubs() int {
	aw.mx.Lock()
	defer aw.mx.Unlock()
	return aw.subs
}

func TestPubsubWithAssortedOptions(t *testing.T) {
	// this test uses assorted options that are not covered in other tests
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hashMsgID := func(m *pb.Message) string {
		hash := sha256.Sum256(m.Data)
		return string(hash[:])
	}

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts,
		WithMessageIdFn(hashMsgID),
		WithPeerOutboundQueueSize(10),
		WithMessageAuthor(""),
		WithBlacklist(NewMapBlacklist()))

	connect(t, hosts[0], hosts[1])

	var subs []*Subscription
	for _, ps := range psubs {
		sub, err := ps.Subscribe([]byte{0x7e, 0x57})
		if err != nil {
			t.Fatal(err)
		}
		subs = append(subs, sub)
	}

	time.Sleep(time.Second)

	for i := 0; i < 2; i++ {
		msg := []byte(fmt.Sprintf("message %d", i))
		psubs[i].Publish([]byte{0x7e, 0x57}, msg)

		for _, sub := range subs {
			assertReceive(t, sub, msg)
		}
	}
}

func TestWithInvalidMessageAuthor(t *testing.T) {
	// this test exercises the failure path in the WithMessageAuthor option
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	h := bhost.NewBlankHost(swarmt.GenSwarm(t))
	defer h.Close()
	_, err := NewFloodSub(ctx, h, WithMessageAuthor("bogotr0n"))
	if err == nil {
		t.Fatal("expected error")
	}
}

func TestPreconnectedNodes(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// If this test fails it may hang so set a timeout
	ctx, cancel = context.WithTimeout(ctx, time.Second*10)
	defer cancel()

	// Create hosts
	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t))
	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t))
	defer h1.Close()
	defer h2.Close()

	opts := []Option{WithDiscovery(&dummyDiscovery{})}
	// Setup first PubSub
	p1, err := NewFloodSub(ctx, h1, opts...)
	if err != nil {
		t.Fatal(err)
	}

	// Connect the two hosts together
	connect(t, h2, h1)

	// Setup the second DHT
	p2, err := NewFloodSub(ctx, h2, opts...)
	if err != nil {
		t.Fatal(err)
	}

	// See if it works
	p2Bitmask, err := p2.Join([]byte{0x7e, 0x57})
	if err != nil {
		t.Fatal(err)
	}

	p1Bitmask, err := p1.Join([]byte{0x7e, 0x57})
	if err != nil {
		t.Fatal(err)
	}

	testPublish := func(publisher, receiver *Bitmask, msg []byte) {
		receiverSub, err := receiver.Subscribe()
		if err != nil {
			t.Fatal(err)
		}

		if err := publisher.Publish(ctx, msg, WithReadiness(MinBitmaskSize(1))); err != nil {
			t.Fatal(err)
		}

		m, err := receiverSub.Next(ctx)
		if err != nil {
			t.Fatal(err)
		}

		if receivedData := m.GetData(); !bytes.Equal(receivedData, msg) {
			t.Fatalf("expected message %v, got %v", msg, receivedData)
		}
	}

	// Test both directions since PubSub uses one directional streams
	testPublish(p1Bitmask, p2Bitmask, []byte("test1-to-2"))
	testPublish(p1Bitmask, p2Bitmask, []byte("test2-to-1"))
}

func TestDedupInboundStreams(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t))
	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t))
	defer h1.Close()
	defer h2.Close()

	_, err := NewFloodSub(ctx, h1)
	if err != nil {
		t.Fatal(err)
	}

	// Connect the two hosts together
	connect(t, h2, h1)

	// open a few streams and make sure all but the last one get reset
	s1, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(100 * time.Millisecond)

	s2, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(100 * time.Millisecond)

	s3, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(100 * time.Millisecond)

	// check that s1 and s2 have been reset
	_, err = s1.Read([]byte{0})
	if err == nil {
		t.Fatal("expected s1 to be reset")
	}

	_, err = s2.Read([]byte{0})
	if err == nil {
		t.Fatal("expected s2 to be reset")
	}

	// check that s3 is readable and simply times out
	s3.SetReadDeadline(time.Now().Add(time.Millisecond))
	_, err = s3.Read([]byte{0})
	err2, ok := err.(interface{ Timeout() bool })
	if !ok || !err2.Timeout() {
		t.Fatal(err)
	}
}