mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2025-01-24 14:45:17 +00:00
217 lines
4.8 KiB
Go
217 lines
4.8 KiB
Go
|
package rcmgr
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||
|
|
||
|
"github.com/multiformats/go-multiaddr"
|
||
|
manet "github.com/multiformats/go-multiaddr/net"
|
||
|
)
|
||
|
|
||
|
type Allowlist struct {
|
||
|
mu sync.RWMutex
|
||
|
// a simple structure of lists of networks. There is probably a faster way
|
||
|
// to check if an IP address is in this network than iterating over this
|
||
|
// list, but this is good enough for small numbers of networks (<1_000).
|
||
|
// Analyze the benchmark before trying to optimize this.
|
||
|
|
||
|
// Any peer with these IPs are allowed
|
||
|
allowedNetworks []*net.IPNet
|
||
|
|
||
|
// Only the specified peers can use these IPs
|
||
|
allowedPeerByNetwork map[peer.ID][]*net.IPNet
|
||
|
}
|
||
|
|
||
|
// WithAllowlistedMultiaddrs sets the multiaddrs to be in the allowlist
|
||
|
func WithAllowlistedMultiaddrs(mas []multiaddr.Multiaddr) Option {
|
||
|
return func(rm *resourceManager) error {
|
||
|
for _, ma := range mas {
|
||
|
err := rm.allowlist.Add(ma)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func newAllowlist() Allowlist {
|
||
|
return Allowlist{
|
||
|
allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) {
|
||
|
var ipString string
|
||
|
var mask string
|
||
|
var allowedPeerStr string
|
||
|
var allowedPeer peer.ID
|
||
|
var isIPV4 bool
|
||
|
|
||
|
multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
|
||
|
if c.Protocol().Code == multiaddr.P_IP4 || c.Protocol().Code == multiaddr.P_IP6 {
|
||
|
isIPV4 = c.Protocol().Code == multiaddr.P_IP4
|
||
|
ipString = c.Value()
|
||
|
}
|
||
|
if c.Protocol().Code == multiaddr.P_IPCIDR {
|
||
|
mask = c.Value()
|
||
|
}
|
||
|
if c.Protocol().Code == multiaddr.P_P2P {
|
||
|
allowedPeerStr = c.Value()
|
||
|
}
|
||
|
return ipString == "" || mask == "" || allowedPeerStr == ""
|
||
|
})
|
||
|
|
||
|
if ipString == "" {
|
||
|
return nil, allowedPeer, errors.New("missing ip address")
|
||
|
}
|
||
|
|
||
|
if allowedPeerStr != "" {
|
||
|
var err error
|
||
|
allowedPeer, err = peer.Decode(allowedPeerStr)
|
||
|
if err != nil {
|
||
|
return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if mask == "" {
|
||
|
ip := net.ParseIP(ipString)
|
||
|
if ip == nil {
|
||
|
return nil, allowedPeer, errors.New("invalid ip address")
|
||
|
}
|
||
|
var mask net.IPMask
|
||
|
if isIPV4 {
|
||
|
mask = net.CIDRMask(32, 32)
|
||
|
} else {
|
||
|
mask = net.CIDRMask(128, 128)
|
||
|
}
|
||
|
|
||
|
net := &net.IPNet{IP: ip, Mask: mask}
|
||
|
return net, allowedPeer, nil
|
||
|
}
|
||
|
|
||
|
_, ipnet, err := net.ParseCIDR(ipString + "/" + mask)
|
||
|
return ipnet, allowedPeer, err
|
||
|
|
||
|
}
|
||
|
|
||
|
// Add takes a multiaddr and adds it to the allowlist. The multiaddr should be
|
||
|
// an ip address of the peer with or without a `/p2p` protocol.
|
||
|
// e.g. /ip4/1.2.3.4/p2p/QmFoo, /ip4/1.2.3.4, and /ip4/1.2.3.0/ipcidr/24 are valid.
|
||
|
// /p2p/QmFoo is not valid.
|
||
|
func (al *Allowlist) Add(ma multiaddr.Multiaddr) error {
|
||
|
ipnet, allowedPeer, err := toIPNet(ma)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
al.mu.Lock()
|
||
|
defer al.mu.Unlock()
|
||
|
|
||
|
if allowedPeer != peer.ID("") {
|
||
|
// We have a peerID constraint
|
||
|
if al.allowedPeerByNetwork == nil {
|
||
|
al.allowedPeerByNetwork = make(map[peer.ID][]*net.IPNet)
|
||
|
}
|
||
|
al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
|
||
|
} else {
|
||
|
al.allowedNetworks = append(al.allowedNetworks, ipnet)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (al *Allowlist) Remove(ma multiaddr.Multiaddr) error {
|
||
|
ipnet, allowedPeer, err := toIPNet(ma)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
al.mu.Lock()
|
||
|
defer al.mu.Unlock()
|
||
|
|
||
|
ipNetList := al.allowedNetworks
|
||
|
|
||
|
if allowedPeer != "" {
|
||
|
// We have a peerID constraint
|
||
|
ipNetList = al.allowedPeerByNetwork[allowedPeer]
|
||
|
}
|
||
|
|
||
|
if ipNetList == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
i := len(ipNetList)
|
||
|
for i > 0 {
|
||
|
i--
|
||
|
if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
|
||
|
// swap remove
|
||
|
ipNetList[i] = ipNetList[len(ipNetList)-1]
|
||
|
ipNetList = ipNetList[:len(ipNetList)-1]
|
||
|
// We only remove one thing
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if allowedPeer != "" {
|
||
|
al.allowedPeerByNetwork[allowedPeer] = ipNetList
|
||
|
} else {
|
||
|
al.allowedNetworks = ipNetList
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (al *Allowlist) Allowed(ma multiaddr.Multiaddr) bool {
|
||
|
ip, err := manet.ToIP(ma)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
al.mu.RLock()
|
||
|
defer al.mu.RUnlock()
|
||
|
|
||
|
for _, network := range al.allowedNetworks {
|
||
|
if network.Contains(ip) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for _, allowedNetworks := range al.allowedPeerByNetwork {
|
||
|
for _, network := range allowedNetworks {
|
||
|
if network.Contains(ip) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (al *Allowlist) AllowedPeerAndMultiaddr(peerID peer.ID, ma multiaddr.Multiaddr) bool {
|
||
|
ip, err := manet.ToIP(ma)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
al.mu.RLock()
|
||
|
defer al.mu.RUnlock()
|
||
|
|
||
|
for _, network := range al.allowedNetworks {
|
||
|
if network.Contains(ip) {
|
||
|
// We found a match that isn't constrained by a peerID
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if expectedNetworks, ok := al.allowedPeerByNetwork[peerID]; ok {
|
||
|
for _, expectedNetwork := range expectedNetworks {
|
||
|
if expectedNetwork.Contains(ip) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|