package application import ( "bytes" "github.com/pkg/errors" "golang.org/x/crypto/sha3" "golang.org/x/sync/errgroup" "source.quilibrium.com/quilibrium/monorepo/nekryptology/pkg/core/curves" ) func ProcessRound( i []byte, idkKey curves.Scalar, round int, peers [][]byte, peerIdks []curves.Point, secrets []curves.Scalar, curve *curves.Curve, send func(int, []byte, []byte) error, recv func(int, []byte) ([]byte, error), seed []byte, ) ([]curves.Scalar, error) { roundPeers, roundIdks, isReceiver := GetPairings(i, round, peers, peerIdks) if roundPeers == nil { return nil, nil } var participants []Iterator if isReceiver { for _, roundIdk := range roundIdks { hashKeySeed := sha3.Sum256( append( roundIdk.Mul(idkKey).ToAffineCompressed(), seed..., ), ) participant := NewMultiplyReceiver(secrets, curve, hashKeySeed) participants = append(participants, participant) if err := participant.Init(); err != nil { return nil, errors.Wrap(err, "process round") } } } else { for _, roundIdk := range roundIdks { hashKeySeed := sha3.Sum256( append( roundIdk.Mul(idkKey).ToAffineCompressed(), seed..., ), ) participant := NewMultiplySender(secrets, curve, hashKeySeed) participants = append(participants, participant) if err := participant.Init(); err != nil { return nil, errors.Wrap(err, "process round") } } } eg := errgroup.Group{} eg.SetLimit(len(participants)) for j := range participants { j := j eg.Go(func() error { var msg []byte seq := 0 for !participants[j].IsDone() { var err error if isReceiver { msg, err = recv(seq, append(append([]byte{}, roundPeers[j]...), i...)) if err != nil { return err } } next, err := participants[j].Next(msg) if err != nil { return err } err = send(seq, append(append([]byte{}, i...), roundPeers[j]...), next) if err != nil { return err } if !isReceiver { msg, err = recv(seq, append(append([]byte{}, roundPeers[j]...), i...)) if err != nil { return err } } seq++ } return nil }) } if err := eg.Wait(); err != nil { return nil, errors.Wrap(err, "process round") } sums := make([]curves.Scalar, len(secrets)) for j := range sums { sums[j] = curve.Scalar.Zero() } for _, participant := range participants { scalars := participant.GetScalars() for j := range sums { sums[j] = sums[j].Add(scalars[j]) } } return sums, nil } func GetPairings(i []byte, round int, peers [][]byte, peerIdks []curves.Point) ( [][]byte, []curves.Point, bool, ) { n := len(peers) index := -1 for j := 0; j < n; j++ { if bytes.Equal([]byte(peers[j]), []byte(i)) { index = j + 1 break } } if index < 1 || index > n { return nil, nil, false // invalid input } power := uint64(n) >> round if power == 0 { return nil, nil, false // rounds exceeded } // Find the size of the subset for this round subsetSize := 1 << (round - 1) // Determine the subset that i belongs to subsetIndex := (index - 1) / subsetSize // If subsetIndex is odd, i's pairings are in the subset before it // If subsetIndex is even, i's pairings are in the subset after it complementarySubsetStart := 0 if subsetIndex%2 == 0 { complementarySubsetStart = (subsetIndex+1)*subsetSize + 1 } else { complementarySubsetStart = subsetIndex*subsetSize - subsetSize + 1 } // Generate the pairings pairings := make([][]byte, subsetSize) idks := make([]curves.Point, subsetSize) for j := 0; j < subsetSize; j++ { pairings[j] = peers[complementarySubsetStart+j-1] idks[j] = peerIdks[complementarySubsetStart+j-1] } return pairings, idks, (index - 1) < complementarySubsetStart }