ceremonyclient/go-libp2p/p2p/host/peerstore/pstoremem/protobook.go

197 lines
3.8 KiB
Go
Raw Permalink Normal View History

2023-08-21 03:50:38 +00:00
package pstoremem
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/peer"
pstore "github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
)
type protoSegment struct {
sync.RWMutex
protocols map[peer.ID]map[protocol.ID]struct{}
}
type protoSegments [256]*protoSegment
func (s *protoSegments) get(p peer.ID) *protoSegment {
2024-06-08 11:32:45 +00:00
return s[p[len(p)-1]]
2023-08-21 03:50:38 +00:00
}
var errTooManyProtocols = errors.New("too many protocols")
type memoryProtoBook struct {
segments protoSegments
maxProtos int
lk sync.RWMutex
interned map[protocol.ID]protocol.ID
}
var _ pstore.ProtoBook = (*memoryProtoBook)(nil)
type ProtoBookOption func(book *memoryProtoBook) error
func WithMaxProtocols(num int) ProtoBookOption {
return func(pb *memoryProtoBook) error {
pb.maxProtos = num
return nil
}
}
func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) {
pb := &memoryProtoBook{
interned: make(map[protocol.ID]protocol.ID, 256),
segments: func() (ret protoSegments) {
for i := range ret {
ret[i] = &protoSegment{
protocols: make(map[peer.ID]map[protocol.ID]struct{}),
}
}
return ret
}(),
maxProtos: 1024,
}
for _, opt := range opts {
if err := opt(pb); err != nil {
return nil, err
}
}
return pb, nil
}
func (pb *memoryProtoBook) internProtocol(proto protocol.ID) protocol.ID {
// check if it is interned with the read lock
pb.lk.RLock()
interned, ok := pb.interned[proto]
pb.lk.RUnlock()
if ok {
return interned
}
// intern with the write lock
pb.lk.Lock()
// check again in case it got interned in between locks
interned, ok = pb.interned[proto]
if ok {
pb.lk.Unlock()
2023-08-21 03:50:38 +00:00
return interned
}
pb.interned[proto] = proto
pb.lk.Unlock()
2023-08-21 03:50:38 +00:00
return proto
}
func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error {
if len(protos) > pb.maxProtos {
return errTooManyProtocols
}
newprotos := make(map[protocol.ID]struct{}, len(protos))
for _, proto := range protos {
newprotos[pb.internProtocol(proto)] = struct{}{}
}
s := pb.segments.get(p)
s.Lock()
s.protocols[p] = newprotos
s.Unlock()
return nil
}
func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
protomap, ok := s.protocols[p]
if !ok {
protomap = make(map[protocol.ID]struct{})
s.protocols[p] = protomap
}
if len(protomap)+len(protos) > pb.maxProtos {
s.Unlock()
2023-08-21 03:50:38 +00:00
return errTooManyProtocols
}
for _, proto := range protos {
protomap[pb.internProtocol(proto)] = struct{}{}
}
s.Unlock()
2023-08-21 03:50:38 +00:00
return nil
}
func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
out := make([]protocol.ID, 0, len(s.protocols[p]))
for k := range s.protocols[p] {
out = append(out, k)
}
s.RUnlock()
2023-08-21 03:50:38 +00:00
return out, nil
}
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error {
s := pb.segments.get(p)
s.Lock()
protomap, ok := s.protocols[p]
if !ok {
// nothing to remove.
s.Unlock()
2023-08-21 03:50:38 +00:00
return nil
}
for _, proto := range protos {
delete(protomap, pb.internProtocol(proto))
}
s.Unlock()
2023-08-21 03:50:38 +00:00
return nil
}
func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
out := make([]protocol.ID, 0, len(protos))
for _, proto := range protos {
if _, ok := s.protocols[p][proto]; ok {
out = append(out, proto)
}
}
s.RUnlock()
2023-08-21 03:50:38 +00:00
return out, nil
}
func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) {
s := pb.segments.get(p)
s.RLock()
for _, proto := range protos {
if _, ok := s.protocols[p][proto]; ok {
s.RUnlock()
2023-08-21 03:50:38 +00:00
return proto, nil
}
}
s.RUnlock()
2023-08-21 03:50:38 +00:00
return "", nil
}
func (pb *memoryProtoBook) RemovePeer(p peer.ID) {
s := pb.segments.get(p)
s.Lock()
delete(s.protocols, p)
s.Unlock()
}