mirror of
https://source.quilibrium.com/quilibrium/ceremonyclient.git
synced 2025-01-12 16:55:18 +00:00
299 lines
5.8 KiB
Go
299 lines
5.8 KiB
Go
|
package mocknet
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
"strconv"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/libp2p/go-libp2p/core/network"
|
||
|
"github.com/libp2p/go-libp2p/core/protocol"
|
||
|
)
|
||
|
|
||
|
var (
	// streamCounter is the process-wide source of stream IDs; newStream
	// takes the next value with Add(1), so the first stream gets ID 1.
	streamCounter atomic.Int64
)
|
||
|
|
||
|
// stream implements network.Stream
|
||
|
type stream struct {
|
||
|
rstream *stream
|
||
|
conn *conn
|
||
|
id int64
|
||
|
|
||
|
write *io.PipeWriter
|
||
|
read *io.PipeReader
|
||
|
toDeliver chan *transportObject
|
||
|
|
||
|
reset chan struct{}
|
||
|
close chan struct{}
|
||
|
closed chan struct{}
|
||
|
|
||
|
writeErr error
|
||
|
|
||
|
protocol atomic.Pointer[protocol.ID]
|
||
|
stat network.Stats
|
||
|
}
|
||
|
|
||
|
var (
	// ErrClosed is recorded as a stream's write error after a clean
	// CloseWrite, and is the error CloseRead closes the read pipe with.
	ErrClosed = errors.New("stream closed")
)
|
||
|
|
||
|
// transportObject is one chunk of written data in flight from a stream to
// its peer.
type transportObject struct {
	msg         []byte    // private copy of the bytes passed to Write
	arrivalTime time.Time // earliest time the chunk may be handed to the peer
}
|
||
|
|
||
|
func newStreamPair() (*stream, *stream) {
|
||
|
ra, wb := io.Pipe()
|
||
|
rb, wa := io.Pipe()
|
||
|
|
||
|
sa := newStream(wa, ra, network.DirOutbound)
|
||
|
sb := newStream(wb, rb, network.DirInbound)
|
||
|
sa.rstream = sb
|
||
|
sb.rstream = sa
|
||
|
return sa, sb
|
||
|
}
|
||
|
|
||
|
func newStream(w *io.PipeWriter, r *io.PipeReader, dir network.Direction) *stream {
|
||
|
s := &stream{
|
||
|
read: r,
|
||
|
write: w,
|
||
|
id: streamCounter.Add(1),
|
||
|
reset: make(chan struct{}, 1),
|
||
|
close: make(chan struct{}, 1),
|
||
|
closed: make(chan struct{}),
|
||
|
toDeliver: make(chan *transportObject),
|
||
|
stat: network.Stats{Direction: dir},
|
||
|
}
|
||
|
|
||
|
go s.transport()
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
// How to handle errors with writes?
|
||
|
func (s *stream) Write(p []byte) (n int, err error) {
|
||
|
l := s.conn.link
|
||
|
delay := l.GetLatency() + l.RateLimit(len(p))
|
||
|
t := time.Now().Add(delay)
|
||
|
|
||
|
// Copy it.
|
||
|
cpy := make([]byte, len(p))
|
||
|
copy(cpy, p)
|
||
|
|
||
|
select {
|
||
|
case <-s.closed: // bail out if we're closing.
|
||
|
return 0, s.writeErr
|
||
|
case s.toDeliver <- &transportObject{msg: cpy, arrivalTime: t}:
|
||
|
}
|
||
|
return len(p), nil
|
||
|
}
|
||
|
|
||
|
func (s *stream) ID() string {
|
||
|
return strconv.FormatInt(s.id, 10)
|
||
|
}
|
||
|
|
||
|
func (s *stream) Protocol() protocol.ID {
|
||
|
p := s.protocol.Load()
|
||
|
if p == nil {
|
||
|
return ""
|
||
|
}
|
||
|
return *p
|
||
|
}
|
||
|
|
||
|
func (s *stream) Stat() network.Stats {
|
||
|
return s.stat
|
||
|
}
|
||
|
|
||
|
func (s *stream) SetProtocol(proto protocol.ID) error {
|
||
|
s.protocol.Store(&proto)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *stream) CloseWrite() error {
|
||
|
select {
|
||
|
case s.close <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
<-s.closed
|
||
|
if s.writeErr != ErrClosed {
|
||
|
return s.writeErr
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *stream) CloseRead() error {
|
||
|
return s.read.CloseWithError(ErrClosed)
|
||
|
}
|
||
|
|
||
|
func (s *stream) Close() error {
|
||
|
_ = s.CloseRead()
|
||
|
return s.CloseWrite()
|
||
|
}
|
||
|
|
||
|
func (s *stream) Reset() error {
|
||
|
// Cancel any pending reads/writes with an error.
|
||
|
s.write.CloseWithError(network.ErrReset)
|
||
|
s.read.CloseWithError(network.ErrReset)
|
||
|
|
||
|
select {
|
||
|
case s.reset <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
<-s.closed
|
||
|
|
||
|
// No meaningful error case here.
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *stream) teardown() {
|
||
|
// at this point, no streams are writing.
|
||
|
s.conn.removeStream(s)
|
||
|
|
||
|
// Mark as closed.
|
||
|
close(s.closed)
|
||
|
}
|
||
|
|
||
|
func (s *stream) Conn() network.Conn {
|
||
|
return s.conn
|
||
|
}
|
||
|
|
||
|
func (s *stream) SetDeadline(t time.Time) error {
|
||
|
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||
|
}
|
||
|
|
||
|
func (s *stream) SetReadDeadline(t time.Time) error {
|
||
|
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||
|
}
|
||
|
|
||
|
func (s *stream) SetWriteDeadline(t time.Time) error {
|
||
|
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||
|
}
|
||
|
|
||
|
func (s *stream) Read(b []byte) (int, error) {
|
||
|
return s.read.Read(b)
|
||
|
}
|
||
|
|
||
|
// transport will grab message arrival times, wait until that time, and
|
||
|
// then write the message out when it is scheduled to arrive
|
||
|
func (s *stream) transport() {
|
||
|
defer s.teardown()
|
||
|
|
||
|
bufsize := 256
|
||
|
buf := new(bytes.Buffer)
|
||
|
timer := time.NewTimer(0)
|
||
|
if !timer.Stop() {
|
||
|
select {
|
||
|
case <-timer.C:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// cleanup
|
||
|
defer timer.Stop()
|
||
|
|
||
|
// writeBuf writes the contents of buf through to the s.Writer.
|
||
|
// done only when arrival time makes sense.
|
||
|
drainBuf := func() error {
|
||
|
if buf.Len() > 0 {
|
||
|
_, err := s.write.Write(buf.Bytes())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
buf.Reset()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// deliverOrWait is a helper func that processes
|
||
|
// an incoming packet. it waits until the arrival time,
|
||
|
// and then writes things out.
|
||
|
deliverOrWait := func(o *transportObject) error {
|
||
|
buffered := len(o.msg) + buf.Len()
|
||
|
|
||
|
// Yes, we can end up extending a timer multiple times if we
|
||
|
// keep on making small writes but that shouldn't be too much of an
|
||
|
// issue. Fixing that would be painful.
|
||
|
if !timer.Stop() {
|
||
|
// FIXME: So, we *shouldn't* need to do this but we hang
|
||
|
// here if we don't... Go bug?
|
||
|
select {
|
||
|
case <-timer.C:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
delay := time.Until(o.arrivalTime)
|
||
|
if delay >= 0 {
|
||
|
timer.Reset(delay)
|
||
|
} else {
|
||
|
timer.Reset(0)
|
||
|
}
|
||
|
|
||
|
if buffered >= bufsize {
|
||
|
select {
|
||
|
case <-timer.C:
|
||
|
case <-s.reset:
|
||
|
select {
|
||
|
case s.reset <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
return network.ErrReset
|
||
|
}
|
||
|
if err := drainBuf(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// write this message.
|
||
|
_, err := s.write.Write(o.msg)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
buf.Write(o.msg)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
// Reset takes precedent.
|
||
|
select {
|
||
|
case <-s.reset:
|
||
|
s.writeErr = network.ErrReset
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-s.reset:
|
||
|
s.writeErr = network.ErrReset
|
||
|
return
|
||
|
case <-s.close:
|
||
|
if err := drainBuf(); err != nil {
|
||
|
s.cancelWrite(err)
|
||
|
return
|
||
|
}
|
||
|
s.writeErr = s.write.Close()
|
||
|
if s.writeErr == nil {
|
||
|
s.writeErr = ErrClosed
|
||
|
}
|
||
|
return
|
||
|
case o := <-s.toDeliver:
|
||
|
if err := deliverOrWait(o); err != nil {
|
||
|
s.cancelWrite(err)
|
||
|
return
|
||
|
}
|
||
|
case <-timer.C: // ok, due to write it out.
|
||
|
if err := drainBuf(); err != nil {
|
||
|
s.cancelWrite(err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *stream) Scope() network.StreamScope {
|
||
|
return &network.NullScope{}
|
||
|
}
|
||
|
|
||
|
func (s *stream) cancelWrite(err error) {
|
||
|
s.write.CloseWithError(err)
|
||
|
s.writeErr = err
|
||
|
}
|