package holepunch

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	logging "github.com/ipfs/go-log/v2"
	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
	"github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
	"github.com/libp2p/go-libp2p/p2p/protocol/identify"
	"github.com/libp2p/go-msgio/pbio"

	ma "github.com/multiformats/go-multiaddr"
)

const (
	// Protocol is the libp2p protocol for Hole Punching.
	Protocol protocol.ID = "/libp2p/dcutr"

	// ServiceName is the name under which hole punch streams are attached to
	// their resource scope (via Scope().SetService).
	ServiceName = "libp2p.holepunch"

	// maxMsgSize is the upper bound (4 KiB) for a single delimited hole punch
	// message; it is also the amount of stream memory reserved per exchange.
	maxMsgSize = 4 << 10
)

var (
	log = logging.Logger("p2p-holepunch")

	// StreamTimeout is the timeout for the hole punch protocol stream.
	StreamTimeout = time.Minute

	// ErrClosed is returned when the hole punching is closed
	ErrClosed = errors.New("hole punching service closing")
)

// Option configures a Service. NewService applies options in order and fails
// with the first error an Option returns.
type Option func(*Service) error

// Service implements the DCUtR protocol.
// The Service runs on every node that supports the DCUtR protocol.
type Service struct {
	// ctx is cancelled by Close. It stops the watchForPublicAddr goroutine and
	// is the context used for hole punches started from handleNewStream.
	ctx       context.Context
	ctxCancel context.CancelFunc

	host host.Host
	ids  identify.IDService

	// holePuncherMx guards holePuncher. holePuncher stays nil until
	// watchForPublicAddr observes private reachability and creates it.
	holePuncherMx sync.Mutex
	holePuncher   *holePuncher

	// hasPublicAddrsChan is closed by watchForPublicAddr right after
	// holePuncher has been set; DirectConnect waits on it.
	hasPublicAddrsChan chan struct{}

	tracer *tracer
	// filter, if non-nil, filters our own and the remote peer's addresses in
	// incomingHolePunch; it is also handed to newHolePuncher.
	filter AddrFilter

	// refCount tracks the watchForPublicAddr goroutine so Close can wait for it.
	refCount sync.WaitGroup
}

// NewService creates a new service that can be used for hole punching.
// The Service runs on all hosts that support the DCUtR protocol,
// no matter if they are behind a NAT / firewall or not.
// The Service handles DCUtR streams (which are initiated from the node behind
// a NAT / firewall once we establish a connection to them through a relay).
//
// ids must be non-nil. Options are applied in order; the first failing option
// aborts construction and its error is returned. On success, a background
// goroutine is started that waits for a public observed address before
// registering the protocol handler; call Close to stop it.
func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) {
	if ids == nil {
		return nil, errors.New("identify service can't be nil")
	}

	ctx, cancel := context.WithCancel(context.Background())
	svc := new(Service)
	svc.ctx, svc.ctxCancel = ctx, cancel
	svc.host, svc.ids = h, ids
	svc.hasPublicAddrsChan = make(chan struct{})

	for _, apply := range opts {
		err := apply(svc)
		if err == nil {
			continue
		}
		// Don't leak the context if construction fails.
		cancel()
		return nil, err
	}

	svc.tracer.Start()
	svc.refCount.Add(1)
	go svc.watchForPublicAddr()
	return svc, nil
}

// watchForPublicAddr runs in the goroutine started by NewService and is
// tracked by s.refCount. It proceeds in two phases:
//
//  1. Poll identify's observed addresses (exponential backoff starting at
//     250ms, capped at 5s) until at least one public address is seen, then
//     register the DCUtR stream handler.
//  2. Wait for an EvtLocalReachabilityChanged event reporting
//     network.ReachabilityPrivate. Only then create the holePuncher and close
//     hasPublicAddrsChan to unblock DirectConnect.
//
// It returns early if the service context is cancelled, if subscribing to the
// event bus fails, or if the subscription channel is closed.
func (s *Service) watchForPublicAddr() {
	// Deferred calls run LIFO, so refCount.Done runs last: by the time Close's
	// Wait returns, the timer is stopped and the subscription is closed.
	defer s.refCount.Done()

	log.Debugw("waiting until we have at least one public address", "peer", s.host.ID())

	// TODO: We should have an event here that fires when identify discovers a new
	// address (and when autonat confirms that address).
	// As we currently don't have an event like this, just check our observed addresses
	// regularly (exponential backoff starting at 250 ms, capped at 5s).
	duration := 250 * time.Millisecond
	const maxDuration = 5 * time.Second
	t := time.NewTimer(duration)
	defer t.Stop()
	for !containsPublicAddr(s.ids.OwnObservedAddrs()) {
		select {
		case <-s.ctx.Done():
			return
		case <-t.C:
			duration *= 2
			if duration > maxDuration {
				duration = maxDuration
			}
			t.Reset(duration)
		}
	}
	log.Debug("Host now has a public address. Starting holepunch protocol.")
	s.host.SetStreamHandler(Protocol, s.handleNewStream)

	// Only start the holePuncher if we're behind a NAT / firewall.
	sub, err := s.host.EventBus().Subscribe(&event.EvtLocalReachabilityChanged{}, eventbus.Name("holepunch"))
	if err != nil {
		log.Debugf("failed to subscribe to Reachability event: %s", err)
		return
	}
	defer sub.Close()
	for {
		select {
		case <-s.ctx.Done():
			return
		case e, ok := <-sub.Out():
			if !ok {
				return
			}
			if e.(event.EvtLocalReachabilityChanged).Reachability != network.ReachabilityPrivate {
				continue
			}
			s.holePuncherMx.Lock()
			s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter)
			s.holePuncherMx.Unlock()
			// Signal DirectConnect only after holePuncher is visible under the lock.
			close(s.hasPublicAddrsChan)
			return
		}
	}
}

// Close closes the Hole Punch Service.
//
// The service context is cancelled and the watchForPublicAddr goroutine is
// waited for *before* anything else is torn down. Until that goroutine exits,
// it may still register the stream handler or create the holePuncher. Tearing
// those down first would race with it: the handler could be re-registered after
// removal, or a holePuncher could be created after the nil check and never
// closed.
//
// It returns the error from closing the holePuncher, if one was started.
func (s *Service) Close() error {
	s.ctxCancel()
	s.refCount.Wait()

	// No new streams should reach handleNewStream once we start closing things.
	s.host.RemoveStreamHandler(Protocol)

	var err error
	s.holePuncherMx.Lock()
	if s.holePuncher != nil {
		err = s.holePuncher.Close()
	}
	s.holePuncherMx.Unlock()
	s.tracer.Close()
	return err
}

// incomingHolePunch runs the responder side of the DCUtR exchange on str.
// The exchange is: read the initiator's CONNECT, reply with our CONNECT, then
// read the initiator's SYNC.
//
// It returns:
//   - rtt: time from writing our CONNECT to receiving SYNC.
//   - remoteAddrs: the initiator's advertised addresses, minus relay addresses
//     and after s.filter (if set).
//   - ownAddrs: the addresses we advertised.
//
// All results are zero values when err is non-nil. The stream must arrive over
// a relayed connection, and both address sets must be non-empty.
func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, remoteAddrs []ma.Multiaddr, ownAddrs []ma.Multiaddr, err error) {
	// sanity check: a hole punch request should only come from peers behind a relay
	if !isRelayAddress(str.Conn().RemoteMultiaddr()) {
		return 0, nil, nil, fmt.Errorf("received hole punch stream on non-relayed connection: %s", str.Conn().RemoteMultiaddr())
	}
	rp := str.Conn().RemotePeer()

	ownAddrs = removeRelayAddrs(s.ids.OwnObservedAddrs())
	if s.filter != nil {
		ownAddrs = s.filter.FilterLocal(rp, ownAddrs)
	}

	// If we can't tell the peer where to dial us, there's no point in starting the hole punching.
	if len(ownAddrs) == 0 {
		return 0, nil, nil, errors.New("rejecting hole punch request, as we don't have any public addresses")
	}

	if err := str.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
		return 0, nil, nil, fmt.Errorf("error reserving memory for stream: %w", err)
	}
	// Release the reservation on every return path below.
	defer str.Scope().ReleaseMemory(maxMsgSize)

	wr := pbio.NewDelimitedWriter(str)
	rd := pbio.NewDelimitedReader(str, maxMsgSize)

	// Read Connect message
	msg := new(pb.HolePunch)

	// Bound the whole exchange. This is best effort: a stream that doesn't
	// support deadlines is still usable.
	str.SetDeadline(time.Now().Add(StreamTimeout))

	if err := rd.ReadMsg(msg); err != nil {
		return 0, nil, nil, fmt.Errorf("failed to read message from initiator: %w", err)
	}
	if t := msg.GetType(); t != pb.HolePunch_CONNECT {
		return 0, nil, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t)
	}

	obsDial := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs))
	if s.filter != nil {
		obsDial = s.filter.FilterRemote(rp, obsDial)
	}

	log.Debugw("received hole punch request", "peer", rp, "addrs", obsDial)
	if len(obsDial) == 0 {
		return 0, nil, nil, errors.New("expected CONNECT message to contain at least one address")
	}

	// Write CONNECT message
	msg.Reset()
	msg.Type = pb.HolePunch_CONNECT.Enum()
	msg.ObsAddrs = addrsToBytes(ownAddrs)
	tstart := time.Now()
	if err := wr.WriteMsg(msg); err != nil {
		return 0, nil, nil, fmt.Errorf("failed to write CONNECT message to initiator: %w", err)
	}

	// Read SYNC message
	msg.Reset()
	if err := rd.ReadMsg(msg); err != nil {
		return 0, nil, nil, fmt.Errorf("failed to read message from initiator: %w", err)
	}
	if t := msg.GetType(); t != pb.HolePunch_SYNC {
		return 0, nil, nil, fmt.Errorf("expected SYNC message from initiator but got %d", t)
	}
	return time.Since(tstart), obsDial, ownAddrs, nil
}

// handleNewStream is the stream handler registered for Protocol.
//
// It runs the responder side of the exchange via incomingHolePunch. On success
// it closes the signalling stream and dials the initiator's addresses with
// holePunchConnect. Each step is reported to the tracer. Streams are reset on
// any error.
func (s *Service) handleNewStream(str network.Stream) {
	conn := str.Conn()

	// Check directionality of the underlying connection.
	// Peer A receives an inbound connection from peer B.
	// Peer A opens a new hole punch stream to peer B.
	// Peer B receives this stream, calling this function.
	// Peer B sees the underlying connection as an outbound connection.
	if conn.Stat().Direction == network.DirInbound {
		str.Reset()
		return
	}

	if err := str.Scope().SetService(ServiceName); err != nil {
		log.Debugf("error attaching stream to holepunch service: %s", err)
		str.Reset()
		return
	}

	remote := conn.RemotePeer()
	rtt, remoteAddrs, localAddrs, err := s.incomingHolePunch(str)
	if err != nil {
		s.tracer.ProtocolError(remote, err)
		log.Debugw("error handling holepunching stream from", "peer", remote, "error", err)
		str.Reset()
		return
	}
	// The exchange is complete; the stream is no longer needed.
	str.Close()

	// Hole punch now by forcing a connect.
	target := peer.AddrInfo{ID: remote, Addrs: remoteAddrs}
	s.tracer.StartHolePunch(remote, remoteAddrs, rtt)
	log.Debugw("starting hole punch", "peer", remote)
	began := time.Now()
	s.tracer.HolePunchAttempt(target.ID)
	punchErr := holePunchConnect(s.ctx, s.host, target, false)
	s.tracer.EndHolePunch(remote, time.Since(began), punchErr)
	s.tracer.HolePunchFinished("receiver", 1, remoteAddrs, localAddrs, getDirectConnection(s.host, remote))
}

// DirectConnect is only exposed for testing purposes.
// TODO: find a solution for this.
//
// DirectConnect waits until watchForPublicAddr has started the holePuncher,
// then asks it to establish a direct connection to p. If the service is closed
// before that happens, it returns ErrClosed instead of blocking forever.
func (s *Service) DirectConnect(p peer.ID) error {
	select {
	case <-s.hasPublicAddrsChan:
	case <-s.ctx.Done():
		return ErrClosed
	}
	s.holePuncherMx.Lock()
	holePuncher := s.holePuncher
	s.holePuncherMx.Unlock()
	// hasPublicAddrsChan is only closed after holePuncher is set, so it is non-nil here.
	return holePuncher.DirectConnect(p)
}