ceremonyclient/go-libp2p-kad-dht/internal/net/message_manager.go

package net

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"

	logging "github.com/ipfs/go-log/v2"
	"github.com/libp2p/go-msgio"

	//lint:ignore SA1019 TODO migrate away from gogo pb
	"github.com/libp2p/go-msgio/protoio"
	"go.opencensus.io/stats"
	"go.opencensus.io/tag"

	"github.com/libp2p/go-libp2p-kad-dht/internal"
	"github.com/libp2p/go-libp2p-kad-dht/metrics"
	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
)

var dhtReadMessageTimeout = 10 * time.Second

// ErrReadTimeout is an error that occurs when no message is read within the timeout period.
var ErrReadTimeout = fmt.Errorf("timed out reading response")

var logger = logging.Logger("dht")

// messageSenderImpl is responsible for sending requests and messages to peers efficiently, including reuse of streams.
// It also tracks metrics for sent requests and messages.
type messageSenderImpl struct {
	host      host.Host // the network services we need
	smlk      sync.Mutex
	strmap    map[peer.ID]*peerMessageSender
	protocols []protocol.ID
}

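// NewMessageSenderImpl constructs a MessageSenderWithDisconnect that opens
// streams to peers over the given host using the supplied DHT protocols.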
func NewMessageSenderImpl(h host.Host, protos []protocol.ID) pb.MessageSenderWithDisconnect {
	return &messageSenderImpl{
		host:      h,
		strmap:    make(map[peer.ID]*peerMessageSender),
		protocols: protos,
	}
}

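// OnDisconnect removes the cached message sender for a disconnected peer and
// invalidates it asynchronously so that its stream is reset rather than reused.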
func (m *messageSenderImpl) OnDisconnect(ctx context.Context, p peer.ID) {
	m.smlk.Lock()
	defer m.smlk.Unlock()
	ms, ok := m.strmap[p]
	if !ok {
		return
	}
	delete(m.strmap, p)

	// Do this asynchronously as ms.lk can block for a while.
	go func() {
		if err := ms.lk.Lock(ctx); err != nil {
			return
		}
		defer ms.lk.Unlock()
		ms.invalidate()
	}()
}

// SendRequest sends out a request, but also makes sure to
// measure the RTT for latency measurements.
func (m *messageSenderImpl) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

	ms, err := m.messageSenderForPeer(ctx, p)
	if err != nil {
		stats.Record(ctx,
			metrics.SentRequests.M(1),
			metrics.SentRequestErrors.M(1),
		)
		logger.Debugw("request failed to open message sender", "error", err, "to", p)
		return nil, err
	}

	start := time.Now()

	rpmes, err := ms.SendRequest(ctx, pmes)
	if err != nil {
		stats.Record(ctx,
			metrics.SentRequests.M(1),
			metrics.SentRequestErrors.M(1),
		)
		logger.Debugw("request failed", "error", err, "to", p)
		return nil, err
	}

	stats.Record(ctx,
		metrics.SentRequests.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
		metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
	)
	m.host.Peerstore().RecordLatency(p, time.Since(start))
	return rpmes, nil
}

// SendMessage sends out a message
func (m *messageSenderImpl) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

	ms, err := m.messageSenderForPeer(ctx, p)
	if err != nil {
		stats.Record(ctx,
			metrics.SentMessages.M(1),
			metrics.SentMessageErrors.M(1),
		)
		logger.Debugw("message failed to open message sender", "error", err, "to", p)
		return err
	}

	if err := ms.SendMessage(ctx, pmes); err != nil {
		stats.Record(ctx,
			metrics.SentMessages.M(1),
			metrics.SentMessageErrors.M(1),
		)
		logger.Debugw("message failed", "error", err, "to", p)
		return err
	}

	stats.Record(ctx,
		metrics.SentMessages.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
	)
	return nil
}

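// messageSenderForPeer returns the cached peerMessageSender for p, creating,
// preparing, and caching a new one if none exists yet.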
func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID) (*peerMessageSender, error) {
	m.smlk.Lock()
	ms, ok := m.strmap[p]
	if ok {
		m.smlk.Unlock()
		return ms, nil
	}
	ms = &peerMessageSender{p: p, m: m, lk: internal.NewCtxMutex()}
	m.strmap[p] = ms
	m.smlk.Unlock()

	if err := ms.prepOrInvalidate(ctx); err != nil {
		m.smlk.Lock()
		defer m.smlk.Unlock()

		if msCur, ok := m.strmap[p]; ok {
			// Changed. Use the new one, old one is invalid and
			// not in the map so we can just throw it away.
			if ms != msCur {
				return msCur, nil
			}
			// Not changed, remove the now invalid stream from the
			// map.
			delete(m.strmap, p)
		}
		// Invalid but not in map. Must have been removed by a disconnect.
		return nil, err
	}
	// All ready to go.
	return ms, nil
}

// peerMessageSender is responsible for sending requests and messages to a particular peer
type peerMessageSender struct {
	s  network.Stream
	r  msgio.ReadCloser
	lk internal.CtxMutex
	p  peer.ID
	m  *messageSenderImpl

	invalid   bool
	singleMes int
}

// invalidate is called before this peerMessageSender is removed from the strmap.
// It prevents the peerMessageSender from being reused/reinitialized and then
// forgotten (leaving the stream open).
func (ms *peerMessageSender) invalidate() {
	ms.invalid = true
	if ms.s != nil {
		_ = ms.s.Reset()
		ms.s = nil
	}
}

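// prepOrInvalidate prepares the stream under the sender's lock, invalidating
// the sender if preparation fails.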
func (ms *peerMessageSender) prepOrInvalidate(ctx context.Context) error {
	if err := ms.lk.Lock(ctx); err != nil {
		return err
	}
	defer ms.lk.Unlock()

	if err := ms.prep(ctx); err != nil {
		ms.invalidate()
		return err
	}
	return nil
}

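// prep opens a new stream to the peer if one is not already open. It fails if
// the sender has been invalidated.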
func (ms *peerMessageSender) prep(ctx context.Context) error {
	if ms.invalid {
		return fmt.Errorf("message sender has been invalidated")
	}
	if ms.s != nil {
		return nil
	}

	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
	// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
	// backwards compatibility reasons).
	nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...)
	if err != nil {
		return err
	}

	ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax)
	ms.s = nstr

	return nil
}

// streamReuseTries is the number of times we will try to reuse a stream to a
// given peer before giving up and reverting to the old one-message-per-stream
// behaviour.
const streamReuseTries = 3
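
// SendMessage writes a single message to the peer, retrying once on a fresh
// stream if the first write fails.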
func (ms *peerMessageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
	if err := ms.lk.Lock(ctx); err != nil {
		return err
	}
	defer ms.lk.Unlock()

	retry := false
	for {
		if err := ms.prep(ctx); err != nil {
			return err
		}

		if err := ms.writeMsg(pmes); err != nil {
			_ = ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Debugw("error writing message", "error", err)
				return err
			}
			logger.Debugw("error writing message", "error", err, "retrying", true)
			retry = true
			continue
		}

		var err error
		if ms.singleMes > streamReuseTries {
			err = ms.s.Close()
			ms.s = nil
		} else if retry {
			ms.singleMes++
		}

		return err
	}
}

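// SendRequest writes a request to the peer and waits for its response,
// retrying once on a fresh stream if the write or read fails (except when the
// context has been canceled).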
func (ms *peerMessageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
	if err := ms.lk.Lock(ctx); err != nil {
		return nil, err
	}
	defer ms.lk.Unlock()

	retry := false
	for {
		if err := ms.prep(ctx); err != nil {
			return nil, err
		}

		if err := ms.writeMsg(pmes); err != nil {
			_ = ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Debugw("error writing message", "error", err)
				return nil, err
			}
			logger.Debugw("error writing message", "error", err, "retrying", true)
			retry = true
			continue
		}

		mes := new(pb.Message)
		if err := ms.ctxReadMsg(ctx, mes); err != nil {
			_ = ms.s.Reset()
			ms.s = nil
			if err == context.Canceled {
				// retry would be same error
				return nil, err
			}
			if retry {
				logger.Debugw("error reading message", "error", err)
				return nil, err
			}
			logger.Debugw("error reading message", "error", err, "retrying", true)
			retry = true
			continue
		}

		var err error
		if ms.singleMes > streamReuseTries {
			err = ms.s.Close()
			ms.s = nil
		} else if retry {
			ms.singleMes++
		}

		return mes, err
	}
}

func (ms *peerMessageSender) writeMsg(pmes *pb.Message) error {
	return WriteMsg(ms.s, pmes)
}

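// ctxReadMsg reads a single message from the stream into mes, bounded by both
// the context and dhtReadMessageTimeout.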
func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
	errc := make(chan error, 1)
	go func(r msgio.ReadCloser) {
		defer close(errc)
		bytes, err := r.ReadMsg()
		defer r.ReleaseMsg(bytes)
		if err != nil {
			errc <- err
			return
		}
		errc <- mes.Unmarshal(bytes)
	}(ms.r)

	t := time.NewTimer(dhtReadMessageTimeout)
	defer t.Stop()

	select {
	case err := <-errc:
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return ErrReadTimeout
	}
}

// The Protobuf writer performs multiple small writes when writing a message.
// We need to buffer those writes, to make sure that we're not sending a new
// packet for every single write.
type bufferedDelimitedWriter struct {
	*bufio.Writer
	protoio.WriteCloser
}

var writerPool = sync.Pool{
	New: func() interface{} {
		w := bufio.NewWriter(nil)
		return &bufferedDelimitedWriter{
			Writer:      w,
			WriteCloser: protoio.NewDelimitedWriter(w),
		}
	},
}

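// WriteMsg writes a varint-delimited protobuf message to w using a pooled
// buffered writer, so the message goes out in as few packets as possible.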
func WriteMsg(w io.Writer, mes *pb.Message) error {
	bw := writerPool.Get().(*bufferedDelimitedWriter)
	bw.Reset(w)
	err := bw.WriteMsg(mes)
	if err == nil {
		err = bw.Flush()
	}
	bw.Reset(nil)
	writerPool.Put(bw)
	return err
}

func (w *bufferedDelimitedWriter) Flush() error {
	return w.Writer.Flush()
}