mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2024-12-27 00:55:17 +00:00
359 lines
9.1 KiB
Go
359 lines
9.1 KiB
Go
|
package quicreuse
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/google/gopacket/routing"
|
||
|
"github.com/libp2p/go-netroute"
|
||
|
"github.com/quic-go/quic-go"
|
||
|
)
|
||
|
|
||
|
type refCountedQuicTransport interface {
|
||
|
LocalAddr() net.Addr
|
||
|
|
||
|
// Used to send packets directly around QUIC. Useful for hole punching.
|
||
|
WriteTo([]byte, net.Addr) (int, error)
|
||
|
|
||
|
Close() error
|
||
|
|
||
|
// count transport reference
|
||
|
DecreaseCount()
|
||
|
IncreaseCount()
|
||
|
|
||
|
Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
|
||
|
Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error)
|
||
|
}
|
||
|
|
||
|
type singleOwnerTransport struct {
|
||
|
quic.Transport
|
||
|
|
||
|
// Used to write packets directly around QUIC.
|
||
|
packetConn net.PacketConn
|
||
|
}
|
||
|
|
||
|
// IncreaseCount does nothing: a singleOwnerTransport has no reference count
// to bump.
func (c *singleOwnerTransport) IncreaseCount() {
	// Intentionally empty.
}
|
||
|
func (c *singleOwnerTransport) DecreaseCount() {
|
||
|
c.Transport.Close()
|
||
|
}
|
||
|
|
||
|
func (c *singleOwnerTransport) LocalAddr() net.Addr {
|
||
|
return c.Transport.Conn.LocalAddr()
|
||
|
}
|
||
|
|
||
|
func (c *singleOwnerTransport) Close() error {
|
||
|
// TODO(when we drop support for go 1.19) use errors.Join
|
||
|
c.Transport.Close()
|
||
|
return c.packetConn.Close()
|
||
|
}
|
||
|
|
||
|
func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||
|
// Safe because we called quic.OptimizeConn ourselves.
|
||
|
return c.packetConn.WriteTo(b, addr)
|
||
|
}
|
||
|
|
||
|
// The values below are logically constants. They are declared as variables
// only so that tests can override them.
var (
	// garbageCollectInterval is how often reuse.gc sweeps for idle transports.
	garbageCollectInterval = 30 * time.Second

	// maxUnusedDuration is how long a transport must have been without
	// references before a sweep may close it.
	maxUnusedDuration = 10 * time.Second
)
|
||
|
|
||
|
type refcountedTransport struct {
|
||
|
quic.Transport
|
||
|
|
||
|
// Used to write packets directly around QUIC.
|
||
|
packetConn net.PacketConn
|
||
|
|
||
|
mutex sync.Mutex
|
||
|
refCount int
|
||
|
unusedSince time.Time
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) IncreaseCount() {
|
||
|
c.mutex.Lock()
|
||
|
c.refCount++
|
||
|
c.unusedSince = time.Time{}
|
||
|
c.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) Close() error {
|
||
|
// TODO(when we drop support for go 1.19) use errors.Join
|
||
|
c.Transport.Close()
|
||
|
return c.packetConn.Close()
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||
|
// Safe because we called quic.OptimizeConn ourselves.
|
||
|
return c.packetConn.WriteTo(b, addr)
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) LocalAddr() net.Addr {
|
||
|
return c.Transport.Conn.LocalAddr()
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) DecreaseCount() {
|
||
|
c.mutex.Lock()
|
||
|
c.refCount--
|
||
|
if c.refCount == 0 {
|
||
|
c.unusedSince = time.Now()
|
||
|
}
|
||
|
c.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (c *refcountedTransport) ShouldGarbageCollect(now time.Time) bool {
|
||
|
c.mutex.Lock()
|
||
|
defer c.mutex.Unlock()
|
||
|
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
|
||
|
}
|
||
|
|
||
|
type reuse struct {
|
||
|
mutex sync.Mutex
|
||
|
|
||
|
closeChan chan struct{}
|
||
|
gcStopChan chan struct{}
|
||
|
|
||
|
routes routing.Router
|
||
|
unicast map[string] /* IP.String() */ map[int] /* port */ *refcountedTransport
|
||
|
// globalListeners contains transports that are listening on 0.0.0.0 / ::
|
||
|
globalListeners map[int]*refcountedTransport
|
||
|
// globalDialers contains transports that we've dialed out from. These transports are listening on 0.0.0.0 / ::
|
||
|
// On Dial, transports are reused from this map if no transport is available in the globalListeners
|
||
|
// On Listen, transports are reused from this map if the requested port is 0, and then moved to globalListeners
|
||
|
globalDialers map[int]*refcountedTransport
|
||
|
|
||
|
statelessResetKey *quic.StatelessResetKey
|
||
|
metricsTracer *metricsTracer
|
||
|
}
|
||
|
|
||
|
func newReuse(srk *quic.StatelessResetKey, mt *metricsTracer) *reuse {
|
||
|
r := &reuse{
|
||
|
unicast: make(map[string]map[int]*refcountedTransport),
|
||
|
globalListeners: make(map[int]*refcountedTransport),
|
||
|
globalDialers: make(map[int]*refcountedTransport),
|
||
|
closeChan: make(chan struct{}),
|
||
|
gcStopChan: make(chan struct{}),
|
||
|
statelessResetKey: srk,
|
||
|
metricsTracer: mt,
|
||
|
}
|
||
|
go r.gc()
|
||
|
return r
|
||
|
}
|
||
|
|
||
|
func (r *reuse) gc() {
|
||
|
defer func() {
|
||
|
r.mutex.Lock()
|
||
|
for _, tr := range r.globalListeners {
|
||
|
tr.Close()
|
||
|
}
|
||
|
for _, tr := range r.globalDialers {
|
||
|
tr.Close()
|
||
|
}
|
||
|
for _, trs := range r.unicast {
|
||
|
for _, tr := range trs {
|
||
|
tr.Close()
|
||
|
}
|
||
|
}
|
||
|
r.mutex.Unlock()
|
||
|
close(r.gcStopChan)
|
||
|
}()
|
||
|
ticker := time.NewTicker(garbageCollectInterval)
|
||
|
defer ticker.Stop()
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-r.closeChan:
|
||
|
return
|
||
|
case <-ticker.C:
|
||
|
now := time.Now()
|
||
|
r.mutex.Lock()
|
||
|
for key, tr := range r.globalListeners {
|
||
|
if tr.ShouldGarbageCollect(now) {
|
||
|
tr.Close()
|
||
|
delete(r.globalListeners, key)
|
||
|
}
|
||
|
}
|
||
|
for key, tr := range r.globalDialers {
|
||
|
if tr.ShouldGarbageCollect(now) {
|
||
|
tr.Close()
|
||
|
delete(r.globalDialers, key)
|
||
|
}
|
||
|
}
|
||
|
for ukey, trs := range r.unicast {
|
||
|
for key, tr := range trs {
|
||
|
if tr.ShouldGarbageCollect(now) {
|
||
|
tr.Close()
|
||
|
delete(trs, key)
|
||
|
}
|
||
|
}
|
||
|
if len(trs) == 0 {
|
||
|
delete(r.unicast, ukey)
|
||
|
// If we've dropped all transports with a unicast binding,
|
||
|
// assume our routes may have changed.
|
||
|
if len(r.unicast) == 0 {
|
||
|
r.routes = nil
|
||
|
} else {
|
||
|
// Ignore the error, there's nothing we can do about
|
||
|
// it.
|
||
|
r.routes, _ = netroute.New()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
r.mutex.Unlock()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
|
||
|
var ip *net.IP
|
||
|
|
||
|
// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
|
||
|
// Otherwise, save some time.
|
||
|
|
||
|
r.mutex.Lock()
|
||
|
router := r.routes
|
||
|
r.mutex.Unlock()
|
||
|
|
||
|
if router != nil {
|
||
|
_, _, src, err := router.Route(raddr.IP)
|
||
|
if err == nil && !src.IsUnspecified() {
|
||
|
ip = &src
|
||
|
}
|
||
|
}
|
||
|
|
||
|
r.mutex.Lock()
|
||
|
defer r.mutex.Unlock()
|
||
|
|
||
|
tr, err := r.transportForDialLocked(network, ip)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
tr.IncreaseCount()
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) {
|
||
|
if source != nil {
|
||
|
// We already have at least one suitable transport...
|
||
|
if trs, ok := r.unicast[source.String()]; ok {
|
||
|
// ... we don't care which port we're dialing from. Just use the first.
|
||
|
for _, tr := range trs {
|
||
|
return tr, nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Use a transport listening on 0.0.0.0 (or ::).
|
||
|
// Again, we don't care about the port number.
|
||
|
for _, tr := range r.globalListeners {
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
// Use a transport we've previously dialed from
|
||
|
for _, tr := range r.globalDialers {
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
// We don't have a transport that we can use for dialing.
|
||
|
// Dial a new connection from a random port.
|
||
|
var addr *net.UDPAddr
|
||
|
switch network {
|
||
|
case "udp4":
|
||
|
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
|
||
|
case "udp6":
|
||
|
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
|
||
|
}
|
||
|
conn, err := listenAndOptimize(network, addr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
tr := &refcountedTransport{Transport: quic.Transport{
|
||
|
Conn: conn,
|
||
|
StatelessResetKey: r.statelessResetKey,
|
||
|
}, packetConn: conn}
|
||
|
if r.metricsTracer != nil {
|
||
|
tr.Transport.Tracer = r.metricsTracer
|
||
|
}
|
||
|
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
|
||
|
r.mutex.Lock()
|
||
|
defer r.mutex.Unlock()
|
||
|
|
||
|
// Check if we can reuse a transport we have already dialed out from.
|
||
|
// We reuse a transport from globalDialers when the requested port is 0 or the requested
|
||
|
// port is already in the globalDialers.
|
||
|
// If we are reusing a transport from globalDialers, we move the globalDialers entry to
|
||
|
// globalListeners
|
||
|
if laddr.IP.IsUnspecified() {
|
||
|
var rTr *refcountedTransport
|
||
|
var localAddr *net.UDPAddr
|
||
|
|
||
|
if laddr.Port == 0 {
|
||
|
// the requested port is 0, we can reuse any transport
|
||
|
for _, tr := range r.globalDialers {
|
||
|
rTr = tr
|
||
|
localAddr = rTr.LocalAddr().(*net.UDPAddr)
|
||
|
delete(r.globalDialers, localAddr.Port)
|
||
|
break
|
||
|
}
|
||
|
} else if _, ok := r.globalDialers[laddr.Port]; ok {
|
||
|
rTr = r.globalDialers[laddr.Port]
|
||
|
localAddr = rTr.LocalAddr().(*net.UDPAddr)
|
||
|
delete(r.globalDialers, localAddr.Port)
|
||
|
}
|
||
|
// found a match
|
||
|
if rTr != nil {
|
||
|
rTr.IncreaseCount()
|
||
|
r.globalListeners[localAddr.Port] = rTr
|
||
|
return rTr, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
conn, err := listenAndOptimize(network, laddr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||
|
tr := &refcountedTransport{Transport: quic.Transport{
|
||
|
Conn: conn,
|
||
|
StatelessResetKey: r.statelessResetKey,
|
||
|
}, packetConn: conn}
|
||
|
if r.metricsTracer != nil {
|
||
|
tr.Transport.Tracer = r.metricsTracer
|
||
|
}
|
||
|
|
||
|
tr.IncreaseCount()
|
||
|
|
||
|
// Deal with listen on a global address
|
||
|
if localAddr.IP.IsUnspecified() {
|
||
|
// The kernel already checked that the laddr is not already listen
|
||
|
// so we need not check here (when we create ListenUDP).
|
||
|
r.globalListeners[localAddr.Port] = tr
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
// Deal with listen on a unicast address
|
||
|
if _, ok := r.unicast[localAddr.IP.String()]; !ok {
|
||
|
r.unicast[localAddr.IP.String()] = make(map[int]*refcountedTransport)
|
||
|
// Assume the system's routes may have changed if we're adding a new listener.
|
||
|
// Ignore the error, there's nothing we can do.
|
||
|
r.routes, _ = netroute.New()
|
||
|
}
|
||
|
|
||
|
// The kernel already checked that the laddr is not already listen
|
||
|
// so we need not check here (when we create ListenUDP).
|
||
|
r.unicast[localAddr.IP.String()][localAddr.Port] = tr
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
func (r *reuse) Close() error {
|
||
|
close(r.closeChan)
|
||
|
<-r.gcStopChan
|
||
|
return nil
|
||
|
}
|