ceremonyclient/go-libp2p/p2p/transport/webrtc/connection.go

256 lines
7.1 KiB
Go
Raw Permalink Normal View History

2024-06-08 11:32:45 +00:00
package libp2pwebrtc
import (
"context"
"errors"
"fmt"
"math"
"net"
"sync"
"sync/atomic"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
)
// Compile-time assertion that *connection satisfies tpt.CapableConn.
var _ tpt.CapableConn = (*connection)(nil)

// maxAcceptQueueLen bounds the number of pending incoming datachannels.
// NOTE(review): not referenced in this file; presumably sizes the channel
// handed to newConnection as incomingDataChannels — confirm at the use site.
const maxAcceptQueueLen = 256
// errConnectionTimeout is the close reason recorded when the underlying peer
// connection reports the Failed or Closed state. It implements net.Error:
// Timeout reports true and Temporary reports false.
type errConnectionTimeout struct{}

// Compile-time assertion that errConnectionTimeout satisfies net.Error.
var _ net.Error = (*errConnectionTimeout)(nil)

// Error implements the error interface.
func (errConnectionTimeout) Error() string { return "connection timeout" }

// Timeout marks the error as a timeout for net.Error consumers.
func (errConnectionTimeout) Timeout() bool { return true }

// Temporary reports that retrying on the same connection will not help.
func (errConnectionTimeout) Temporary() bool { return false }
// dataChannel pairs a WebRTC datachannel with the ReadWriteCloser used to do
// stream I/O on it. Values of this type are delivered to AcceptStream through
// the connection's acceptQueue.
type dataChannel struct {
	// stream is the handle passed to newStream for reads and writes.
	// NOTE(review): presumably obtained via DataChannel.Detach by the producer
	// of incoming channels, as OpenStream does for outgoing ones — confirm.
	stream datachannel.ReadWriteCloser
	// channel is the datachannel itself; AcceptStream uses its ID to register
	// and later remove the stream.
	channel *webrtc.DataChannel
}
// connection is a tpt.CapableConn backed by a single WebRTC peer connection.
// Every stream is carried by its own datachannel.
type connection struct {
	pc        *webrtc.PeerConnection
	transport *WebRTCTransport
	scope     network.ConnManagementScope

	// closeOnce ensures closeWithError runs at most once, whether triggered by
	// Close or by onConnectionStateChange. closeErr is written before ctx is
	// cancelled and is the error returned by operations attempted after close.
	closeOnce sync.Once
	closeErr  error

	localPeer      peer.ID
	localMultiaddr ma.Multiaddr

	remotePeer      peer.ID
	remoteKey       ic.PubKey
	remoteMultiaddr ma.Multiaddr

	// m guards streams. streams maps datachannel IDs to live streams and is
	// set to nil on close, after which addStream rejects new streams.
	m       sync.Mutex
	streams map[uint16]*stream
	// nextStreamID is the next ID OpenStream will allocate. It advances by two
	// per call; newConnection seeds it with 1 (inbound) or 2 (outbound).
	nextStreamID atomic.Int32

	// acceptQueue is the incomingDataChannels channel given to newConnection;
	// AcceptStream receives from it.
	acceptQueue chan dataChannel

	// ctx is cancelled when the connection is closed.
	ctx    context.Context
	cancel context.CancelFunc
}
// newConnection wraps pc in a connection and registers a state-change handler
// that closes the connection once pc reports the Failed or Closed state.
//
// direction selects where OpenStream starts allocating stream IDs: inbound
// connections use odd IDs starting at 1, outbound connections use even IDs
// starting at 2 (ID 0 is used for the Noise handshake stream). Any other
// direction leaves the counter at 0. incomingDataChannels becomes the queue
// AcceptStream reads from. The returned error is currently always nil.
func newConnection(
	direction network.Direction,
	pc *webrtc.PeerConnection,
	transport *WebRTCTransport,
	scope network.ConnManagementScope,
	localPeer peer.ID,
	localMultiaddr ma.Multiaddr,
	remotePeer peer.ID,
	remoteKey ic.PubKey,
	remoteMultiaddr ma.Multiaddr,
	incomingDataChannels chan dataChannel,
) (*connection, error) {
	ctx, cancel := context.WithCancel(context.Background())

	conn := &connection{
		pc:        pc,
		transport: transport,
		scope:     scope,

		localPeer:       localPeer,
		localMultiaddr:  localMultiaddr,
		remotePeer:      remotePeer,
		remoteKey:       remoteKey,
		remoteMultiaddr: remoteMultiaddr,

		streams:     make(map[uint16]*stream),
		acceptQueue: incomingDataChannels,

		ctx:    ctx,
		cancel: cancel,
	}

	// OpenStream advances the counter by two, so the starting parity decides
	// which half of the ID space this side allocates from (presumably the
	// remote side uses the opposite direction, keeping the halves disjoint).
	var firstStreamID int32
	switch direction {
	case network.DirInbound:
		firstStreamID = 1
	case network.DirOutbound:
		// stream ID 0 is used for the Noise handshake stream
		firstStreamID = 2
	}
	conn.nextStreamID.Store(firstStreamID)

	pc.OnConnectionStateChange(conn.onConnectionStateChange)
	return conn, nil
}
// ConnState implements transport.CapableConn. The only field populated is the
// transport name, "webrtc-direct".
func (c *connection) ConnState() network.ConnectionState {
	var state network.ConnectionState
	state.Transport = "webrtc-direct"
	return state
}
// Close shuts the connection down, closing the underlying peer connection.
// The shutdown runs at most once, shared via closeOnce with the state-change
// handler; calls after the first are no-ops. Close always returns nil.
func (c *connection) Close() error {
	shutdown := func() {
		c.closeWithError(errors.New("connection closed"))
	}
	c.closeOnce.Do(shutdown)
	return nil
}
// closeWithError tears the connection down and records err as the reason.
// It is used both by Close and when the underlying connection fails.
//
// It is not idempotent on its own (scope.Done is called unconditionally);
// callers run it inside closeOnce.Do.
func (c *connection) closeWithError(err error) {
	// Publish the reason before cancelling: goroutines that observe ctx.Done
	// may then read closeErr without holding the connection lock.
	c.closeErr = err
	c.cancel()

	// Closing the peer connection also closes the datachannels backing the
	// streams. The error is ignored, as before: this is best-effort teardown.
	_ = c.pc.Close()

	// Take ownership of the stream table. Leaving a nil map behind makes
	// addStream refuse any stream that races with shutdown.
	open := func() map[uint16]*stream {
		c.m.Lock()
		defer c.m.Unlock()
		taken := c.streams
		c.streams = nil
		return taken
	}()

	// Stream callbacks may call removeStream, so the lock must not be held here.
	for _, str := range open {
		str.closeForShutdown(err)
	}

	c.scope.Done()
}
// IsClosed reports whether the connection has been closed, i.e. whether its
// context has been cancelled.
func (c *connection) IsClosed() bool {
	if err := c.ctx.Err(); err != nil {
		return true
	}
	return false
}
// OpenStream opens a new outgoing stream backed by a freshly created
// datachannel with a pre-assigned ID.
//
// IDs come from nextStreamID in steps of two (see newConnection for the
// starting values). ctx only bounds the wait for the datachannel to open and
// be detached; it is not attached to the returned stream. The stream is
// registered with the connection and is given a callback that removes it again.
//
// It returns closeErr if the connection is already closed, and an error if the
// stream ID space is exhausted or the datachannel cannot be created, detached,
// or registered.
func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error) {
	if c.IsClosed() {
		return nil, c.closeErr
	}

	id := c.nextStreamID.Add(2) - 2
	// SCTP stream identifier 65535 is reserved (RFC 8831, Section 6.5), so the
	// usable range is 0..65534. The counter keeps advancing on every call even
	// after exhaustion, so also reject a value that has wrapped negative rather
	// than letting uint16(id) silently reuse a low ID.
	if id < 0 || id >= math.MaxUint16 {
		return nil, errors.New("exhausted stream ID space")
	}
	streamID := uint16(id)

	dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID})
	if err != nil {
		return nil, err
	}

	rwc, err := c.detachChannel(ctx, dc)
	if err != nil {
		dc.Close()
		return nil, fmt.Errorf("detach channel failed for stream(%d): %w", streamID, err)
	}

	str := newStream(dc, rwc, func() { c.removeStream(streamID) })
	if err := c.addStream(str); err != nil {
		str.Reset()
		return nil, fmt.Errorf("failed to add stream(%d) to connection: %w", streamID, err)
	}
	return str, nil
}
// AcceptStream blocks until a datachannel arrives on acceptQueue or the
// connection is closed, in which case it returns closeErr. The received
// datachannel is wrapped in a stream and registered with the connection. If
// registration fails, the stream is reset and the error is returned.
func (c *connection) AcceptStream() (network.MuxedStream, error) {
	var incoming dataChannel
	select {
	case <-c.ctx.Done():
		return nil, c.closeErr
	case incoming = <-c.acceptQueue:
	}

	onDone := func() { c.removeStream(*incoming.channel.ID()) }
	str := newStream(incoming.channel, incoming.stream, onDone)
	if err := c.addStream(str); err != nil {
		str.Reset()
		return nil, err
	}
	return str, nil
}
// LocalPeer returns the ID of the local peer.
func (c *connection) LocalPeer() peer.ID {
	return c.localPeer
}

// RemotePeer returns the ID of the remote peer.
func (c *connection) RemotePeer() peer.ID {
	return c.remotePeer
}

// RemotePublicKey returns the remote peer's public key.
func (c *connection) RemotePublicKey() ic.PubKey {
	return c.remoteKey
}

// LocalMultiaddr returns the local address of the connection.
func (c *connection) LocalMultiaddr() ma.Multiaddr {
	return c.localMultiaddr
}

// RemoteMultiaddr returns the remote address of the connection.
func (c *connection) RemoteMultiaddr() ma.Multiaddr {
	return c.remoteMultiaddr
}

// Scope returns the resource-management scope of the connection.
func (c *connection) Scope() network.ConnScope {
	return c.scope
}

// Transport returns the transport that created the connection.
func (c *connection) Transport() tpt.Transport {
	return c.transport
}
// addStream registers str under its ID. It returns closeErr once the
// connection has been closed (the stream table is nil), and an error if the
// ID is already in use.
func (c *connection) addStream(str *stream) error {
	c.m.Lock()
	defer c.m.Unlock()

	table := c.streams
	if table == nil {
		return c.closeErr
	}

	key := str.id
	if _, taken := table[key]; taken {
		return errors.New("stream ID already exists")
	}
	table[key] = str
	return nil
}
// removeStream forgets the stream with the given ID. It is safe to call after
// close: deleting from the nil stream table is a no-op.
func (c *connection) removeStream(id uint16) {
	c.m.Lock()
	delete(c.streams, id)
	c.m.Unlock()
}
// onConnectionStateChange is the peer connection's state-change handler. When
// the peer connection reaches the Failed or Closed state, it closes the
// connection with errConnectionTimeout, sharing closeOnce with Close.
func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
	switch state {
	case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed:
		c.closeOnce.Do(func() {
			c.closeWithError(errConnectionTimeout{})
		})
	}
}
// detachChannel waits for an outgoing datachannel to open and then detaches
// it. It gives up early if ctx (the context passed to OpenStream) is done, or
// if the connection is closed (returning closeErr).
//
// Background: webrtc.DataChannel wraps pion's datachannel.DataChannel, which
// in turn wraps an sctp.Stream. The SCTP stream already implements net.Conn,
// but webrtc.DataChannel normally starts a goroutine that reads from it
// continuously and surfaces the data through the OnMessage callback. The
// reading goroutine, Detach, and OnMessage all live at the webrtc.DataChannel
// level. Backpressure could not practically be applied through OnMessage, so
// we call Detach instead, which gives direct Read access to the underlying
// datachannel. The tradeoff is a change in the semantics of the OnOpen
// callback, and having to force-close Read locally.
//
// NOTE(review): if this function returns early and OnOpen fires afterwards,
// the detached ReadWriteCloser produced there is discarded without being
// closed; OpenStream does call dc.Close() on error — confirm that releases it.
func (c *connection) detachChannel(ctx context.Context, dc *webrtc.DataChannel) (datachannel.ReadWriteCloser, error) {
	var (
		detached  datachannel.ReadWriteCloser
		detachErr error
	)
	opened := make(chan struct{})

	// OnOpen will return immediately for detached datachannels
	// refer: https://github.com/pion/webrtc/blob/7ab3174640b3ce15abebc2516a2ca3939b5f105f/datachannel.go#L278-L282
	dc.OnOpen(func() {
		detached, detachErr = dc.Detach()
		// Closing opened publishes detached/detachErr to the select below: the
		// close happens-before the receive that observes it.
		close(opened)
	})

	select {
	case <-opened:
		return detached, detachErr
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-c.ctx.Done():
		return nil, c.closeErr
	}
}