add simple rdf + tr

This commit is contained in:
Cassandra Heart 2024-06-21 01:28:02 -05:00
parent 310bd6fc11
commit a428fe51fd
No known key found for this signature in database
GPG Key ID: 6352152859385958
6 changed files with 2001 additions and 0 deletions

View File

@ -0,0 +1,337 @@
package channel
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"github.com/pkg/errors"
"source.quilibrium.com/quilibrium/monorepo/nekryptology/pkg/core/curves"
)
type Feldman struct {
threshold int
total int
id int
fragsForCounterparties map[int][]byte
fragsFromCounterparties map[int]curves.Scalar
zkpok curves.Scalar
secret curves.Scalar
scalar curves.Scalar
generator curves.Point
publicKey curves.Point
point curves.Point
randomCommitmentPoint curves.Point
round FeldmanRound
zkcommitsFromCounterparties map[int][]byte
pointsFromCounterparties map[int]curves.Point
curve curves.Curve
}
type FeldmanReveal struct {
Point []byte
RandomCommitmentPoint []byte
ZKPoK []byte
}
var ErrWrongRound = errors.New("wrong round for feldman")
type FeldmanRound int
const (
FELDMAN_ROUND_UNINITIALIZED = FeldmanRound(0)
FELDMAN_ROUND_INITIALIZED = FeldmanRound(1)
FELDMAN_ROUND_COMMITTED = FeldmanRound(2)
FELDMAN_ROUND_REVEALED = FeldmanRound(3)
FELDMAN_ROUND_RECONSTRUCTED = FeldmanRound(4)
)
func NewFeldman(
threshold, total, id int,
secret curves.Scalar,
curve curves.Curve,
generator curves.Point,
) (*Feldman, error) {
return &Feldman{
threshold: threshold,
total: total,
id: id,
fragsForCounterparties: make(map[int][]byte),
fragsFromCounterparties: make(map[int]curves.Scalar),
zkpok: nil,
secret: secret,
scalar: nil,
generator: generator,
publicKey: secret.Point().Generator(),
point: secret.Point().Generator(),
round: FELDMAN_ROUND_UNINITIALIZED,
zkcommitsFromCounterparties: make(map[int][]byte),
pointsFromCounterparties: make(map[int]curves.Point),
curve: curve,
}, nil
}
func (f *Feldman) SamplePolynomial() error {
if f.round != FELDMAN_ROUND_UNINITIALIZED {
return errors.Wrap(ErrWrongRound, "sample polynomial")
}
coeffs := append([]curves.Scalar{}, f.secret)
for i := 1; i < f.threshold; i++ {
secret := f.curve.NewScalar()
secret = secret.Random(rand.Reader)
coeffs = append(coeffs, secret)
}
for i := 1; i <= f.total; i++ {
result := coeffs[0].Clone()
x := f.curve.Scalar.New(i)
for j := 1; j < f.threshold; j++ {
term := coeffs[j].Mul(x)
result = result.Add(term)
x = x.Mul(f.curve.Scalar.New(i))
}
if i == f.id {
f.scalar = result
} else {
fragBytes := result.Bytes()
f.fragsForCounterparties[i] = fragBytes
}
}
f.round = FELDMAN_ROUND_INITIALIZED
return nil
}
func (f *Feldman) Scalar() curves.Scalar {
return f.scalar
}
func (f *Feldman) GetPolyFrags() (map[int][]byte, error) {
if f.round != FELDMAN_ROUND_INITIALIZED {
return nil, errors.Wrap(ErrWrongRound, "get poly frags")
}
return f.fragsForCounterparties, nil
}
func (f *Feldman) SetPolyFragForParty(id int, frag []byte) ([]byte, error) {
if f.round != FELDMAN_ROUND_INITIALIZED {
return nil, errors.Wrap(ErrWrongRound, "set poly frag for party")
}
var err error
f.fragsFromCounterparties[id], err = f.curve.NewScalar().SetBytes(frag)
if err != nil {
return nil, errors.Wrap(err, "set poly frag for party")
}
if len(f.fragsFromCounterparties) == f.total-1 {
for _, v := range f.fragsFromCounterparties {
f.scalar = f.scalar.Add(v)
}
f.point = f.generator.Mul(f.scalar)
randCommitment := f.curve.NewScalar().Random(rand.Reader)
f.randomCommitmentPoint = f.generator.Mul(randCommitment)
randCommitmentPointBytes := f.randomCommitmentPoint.ToAffineCompressed()
publicPointBytes := f.point.ToAffineCompressed()
challenge := sha256.Sum256(
append(
append([]byte{}, publicPointBytes...),
randCommitmentPointBytes...,
),
)
challengeBig, err := f.curve.NewScalar().SetBigInt(
new(big.Int).SetBytes(challenge[:]),
)
if err != nil {
return nil, errors.Wrap(err, "set poly frag for party")
}
f.zkpok = f.scalar.Mul(challengeBig).Add(randCommitment)
zkpokBytes := f.zkpok.Bytes()
zkcommit := sha256.Sum256(
append(
append([]byte{}, randCommitmentPointBytes...),
zkpokBytes...,
),
)
f.round = FELDMAN_ROUND_COMMITTED
return zkcommit[:], nil
}
return []byte{}, nil
}
func (f *Feldman) ReceiveCommitments(
id int,
zkcommit []byte,
) (*FeldmanReveal, error) {
if f.round != FELDMAN_ROUND_COMMITTED {
return nil, errors.Wrap(ErrWrongRound, "receive commitments")
}
f.zkcommitsFromCounterparties[id] = zkcommit
if len(f.zkcommitsFromCounterparties) == f.total-1 {
publicPointBytes := f.point.ToAffineCompressed()
randCommitmentPointBytes := f.randomCommitmentPoint.ToAffineCompressed()
f.round = FELDMAN_ROUND_REVEALED
zkpokBytes := f.zkpok.Bytes()
return &FeldmanReveal{
Point: publicPointBytes,
RandomCommitmentPoint: randCommitmentPointBytes,
ZKPoK: zkpokBytes,
}, nil
}
return nil, nil
}
func (f *Feldman) Recombine(id int, reveal *FeldmanReveal) (bool, error) {
if f.round != FELDMAN_ROUND_REVEALED {
return false, errors.Wrap(ErrWrongRound, "recombine")
}
counterpartyPoint, err := f.curve.NewGeneratorPoint().FromAffineCompressed(
reveal.Point,
)
if err != nil {
return false, errors.Wrap(err, "recombine")
}
if counterpartyPoint.Equal(f.curve.NewGeneratorPoint()) ||
counterpartyPoint.Equal(f.generator) {
return false, errors.Wrap(errors.New("counterparty sent generator"), "recombine")
}
counterpartyRandomCommitmentPoint, err := f.curve.NewGeneratorPoint().
FromAffineCompressed(reveal.RandomCommitmentPoint)
if err != nil {
return false, errors.Wrap(err, "recombine")
}
if counterpartyRandomCommitmentPoint.Equal(f.curve.NewGeneratorPoint()) ||
counterpartyRandomCommitmentPoint.Equal(f.generator) {
return false, errors.Wrap(errors.New("counterparty sent generator"), "recombine")
}
counterpartyZKPoK, err := f.curve.NewScalar().SetBytes(reveal.ZKPoK)
if err != nil {
return false, errors.Wrap(err, "recombine")
}
counterpartyZKCommit := f.zkcommitsFromCounterparties[id]
challenge := sha256.Sum256(append(
append([]byte{}, reveal.Point...),
reveal.RandomCommitmentPoint...,
))
challengeBig, err := f.curve.NewScalar().SetBigInt(
new(big.Int).SetBytes(challenge[:]),
)
if err != nil {
return false, errors.Wrap(err, "recombine")
}
proof := f.generator.Mul(counterpartyZKPoK)
counterpartyRandomCommitmentPoint = counterpartyRandomCommitmentPoint.Add(
counterpartyPoint.Mul(challengeBig),
)
if !proof.Equal(counterpartyRandomCommitmentPoint) {
return false, errors.Wrap(
errors.New(fmt.Sprintf("invalid proof from %d", id)),
"recombine",
)
}
verifier := sha256.Sum256(append(
append([]byte{}, reveal.RandomCommitmentPoint...),
reveal.ZKPoK...,
))
if !bytes.Equal(counterpartyZKCommit, verifier[:]) {
return false, errors.Wrap(
errors.New(fmt.Sprintf("%d changed zkpok after commit", id)),
"recombine",
)
}
f.pointsFromCounterparties[id] = counterpartyPoint
if len(f.pointsFromCounterparties) == f.total-1 {
f.pointsFromCounterparties[f.id] = f.point
for i := 1; i <= f.total-f.threshold+1; i++ {
var reconstructedSum curves.Point = nil
for j := i; j < f.threshold+i; j++ {
num := f.curve.Scalar.One()
den := f.curve.Scalar.One()
for k := i; k < f.threshold+i; k++ {
if j != k {
j := f.curve.NewScalar().New(j)
k := f.curve.NewScalar().New(k)
num = num.Mul(k)
den = den.Mul(k.Sub(j))
}
}
den, _ = den.Invert()
reconstructedFragment := f.pointsFromCounterparties[j].Mul(num.Mul(den))
if reconstructedSum == nil {
reconstructedSum = reconstructedFragment
} else {
reconstructedSum = reconstructedSum.Add(reconstructedFragment)
}
}
if f.publicKey.Equal(f.curve.NewGeneratorPoint()) ||
f.publicKey.Equal(f.generator) {
f.publicKey = reconstructedSum
} else if !f.publicKey.Equal(reconstructedSum) {
return false, errors.Wrap(
errors.New("recombination mismatch"),
"recombine",
)
}
}
f.round = FELDMAN_ROUND_RECONSTRUCTED
}
return f.round == FELDMAN_ROUND_RECONSTRUCTED, nil
}
func (f *Feldman) PublicKey() curves.Point {
return f.publicKey
}
func (f *Feldman) PublicKeyBytes() []byte {
return f.publicKey.ToAffineCompressed()
}
func ReverseScalarBytes(inBytes []byte, length int) []byte {
outBytes := make([]byte, length)
for i, j := 0, len(inBytes)-1; j >= 0; i, j = i+1, j-1 {
outBytes[i] = inBytes[j]
}
return outBytes
}

View File

@ -0,0 +1,446 @@
package channel_test
import (
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
"source.quilibrium.com/quilibrium/monorepo/nekryptology/pkg/core/curves"
"source.quilibrium.com/quilibrium/monorepo/node/crypto"
)
func TestFeldman(t *testing.T) {
s1 := curves.ED25519().NewScalar().Random(rand.Reader)
f1, err := crypto.NewFeldman(
3,
5,
1,
s1,
*curves.ED25519(),
curves.ED25519().NewGeneratorPoint(),
)
assert.NoError(t, err)
s2 := curves.ED25519().NewScalar().Random(rand.Reader)
f2, err := crypto.NewFeldman(
3,
5,
2,
s2,
*curves.ED25519(),
curves.ED25519().NewGeneratorPoint(),
)
assert.NoError(t, err)
s3 := curves.ED25519().NewScalar().Random(rand.Reader)
f3, err := crypto.NewFeldman(
3,
5,
3,
s3,
*curves.ED25519(),
curves.ED25519().NewGeneratorPoint(),
)
assert.NoError(t, err)
s4 := curves.ED25519().NewScalar().Random(rand.Reader)
f4, err := crypto.NewFeldman(
3,
5,
4,
s4,
*curves.ED25519(),
curves.ED25519().NewGeneratorPoint(),
)
assert.NoError(t, err)
s5 := curves.ED25519().NewScalar().Random(rand.Reader)
f5, err := crypto.NewFeldman(
3,
5,
5,
s5,
*curves.ED25519(),
curves.ED25519().NewGeneratorPoint(),
)
assert.NoError(t, err)
err = f1.SamplePolynomial()
assert.NoError(t, err)
err = f2.SamplePolynomial()
assert.NoError(t, err)
err = f3.SamplePolynomial()
assert.NoError(t, err)
err = f4.SamplePolynomial()
assert.NoError(t, err)
err = f5.SamplePolynomial()
assert.NoError(t, err)
m1, err := f1.GetPolyFrags()
assert.NoError(t, err)
m2, err := f2.GetPolyFrags()
assert.NoError(t, err)
m3, err := f3.GetPolyFrags()
assert.NoError(t, err)
m4, err := f4.GetPolyFrags()
assert.NoError(t, err)
m5, err := f5.GetPolyFrags()
assert.NoError(t, err)
m1[1] = f1.Scalar().Bytes()
_, err = f1.SetPolyFragForParty(2, m2[1])
assert.NoError(t, err)
_, err = f1.SetPolyFragForParty(3, m3[1])
assert.NoError(t, err)
_, err = f1.SetPolyFragForParty(4, m4[1])
assert.NoError(t, err)
z1, err := f1.SetPolyFragForParty(5, m5[1])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(1, m1[2])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(3, m3[2])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(4, m4[2])
assert.NoError(t, err)
z2, err := f2.SetPolyFragForParty(5, m5[2])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(1, m1[3])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(2, m2[3])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(4, m4[3])
assert.NoError(t, err)
z3, err := f3.SetPolyFragForParty(5, m5[3])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(1, m1[4])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(2, m2[4])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(3, m3[4])
assert.NoError(t, err)
z4, err := f4.SetPolyFragForParty(5, m5[4])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(1, m1[5])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(2, m2[5])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(3, m3[5])
assert.NoError(t, err)
z5, err := f5.SetPolyFragForParty(4, m4[5])
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(2, z2)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(3, z3)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(4, z4)
assert.NoError(t, err)
assert.NoError(t, err)
r1, err := f1.ReceiveCommitments(5, z5)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(3, z3)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(4, z4)
assert.NoError(t, err)
r2, err := f2.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(4, z4)
assert.NoError(t, err)
r3, err := f3.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(3, z3)
assert.NoError(t, err)
r4, err := f4.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(3, z3)
assert.NoError(t, err)
r5, err := f5.ReceiveCommitments(4, z4)
assert.NoError(t, err)
_, err = f1.Recombine(2, r2)
assert.NoError(t, err)
_, err = f1.Recombine(3, r3)
assert.NoError(t, err)
_, err = f1.Recombine(4, r4)
assert.NoError(t, err)
_, err = f1.Recombine(5, r5)
assert.NoError(t, err)
_, err = f2.Recombine(1, r1)
assert.NoError(t, err)
_, err = f2.Recombine(3, r3)
assert.NoError(t, err)
_, err = f2.Recombine(4, r4)
assert.NoError(t, err)
_, err = f2.Recombine(5, r5)
assert.NoError(t, err)
_, err = f3.Recombine(1, r1)
assert.NoError(t, err)
_, err = f3.Recombine(2, r2)
assert.NoError(t, err)
_, err = f3.Recombine(4, r4)
assert.NoError(t, err)
_, err = f3.Recombine(5, r5)
assert.NoError(t, err)
_, err = f4.Recombine(1, r1)
assert.NoError(t, err)
_, err = f4.Recombine(2, r2)
assert.NoError(t, err)
_, err = f4.Recombine(3, r3)
assert.NoError(t, err)
_, err = f4.Recombine(5, r5)
assert.NoError(t, err)
_, err = f5.Recombine(1, r1)
assert.NoError(t, err)
_, err = f5.Recombine(2, r2)
assert.NoError(t, err)
_, err = f5.Recombine(3, r3)
assert.NoError(t, err)
_, err = f5.Recombine(4, r4)
assert.NoError(t, err)
s := s1.Add(s2.Add(s3.Add(s4.Add(s5))))
assert.True(t, curves.ED25519().NewGeneratorPoint().Mul(s).Equal(f1.PublicKey()))
assert.True(t, f5.PublicKey().Equal(f1.PublicKey()))
}
func TestFeldmanCustomGenerator(t *testing.T) {
gen := curves.ED25519().Point.Random(rand.Reader)
f1, err := crypto.NewFeldman(
3,
5,
1,
curves.ED25519().NewScalar().Random(rand.Reader),
*curves.ED25519(),
gen,
)
assert.NoError(t, err)
f2, err := crypto.NewFeldman(
3,
5,
2,
curves.ED25519().NewScalar().Random(rand.Reader),
*curves.ED25519(),
gen,
)
assert.NoError(t, err)
f3, err := crypto.NewFeldman(
3,
5,
3,
curves.ED25519().NewScalar().Random(rand.Reader),
*curves.ED25519(),
gen,
)
assert.NoError(t, err)
f4, err := crypto.NewFeldman(
3,
5,
4,
curves.ED25519().NewScalar().Random(rand.Reader),
*curves.ED25519(),
gen,
)
assert.NoError(t, err)
f5, err := crypto.NewFeldman(
3,
5,
5,
curves.ED25519().NewScalar().Random(rand.Reader),
*curves.ED25519(),
gen,
)
assert.NoError(t, err)
err = f1.SamplePolynomial()
assert.NoError(t, err)
err = f2.SamplePolynomial()
assert.NoError(t, err)
err = f3.SamplePolynomial()
assert.NoError(t, err)
err = f4.SamplePolynomial()
assert.NoError(t, err)
err = f5.SamplePolynomial()
assert.NoError(t, err)
m1, err := f1.GetPolyFrags()
assert.NoError(t, err)
m2, err := f2.GetPolyFrags()
assert.NoError(t, err)
m3, err := f3.GetPolyFrags()
assert.NoError(t, err)
m4, err := f4.GetPolyFrags()
assert.NoError(t, err)
m5, err := f5.GetPolyFrags()
assert.NoError(t, err)
_, err = f1.SetPolyFragForParty(2, m2[1])
assert.NoError(t, err)
_, err = f1.SetPolyFragForParty(3, m3[1])
assert.NoError(t, err)
_, err = f1.SetPolyFragForParty(4, m4[1])
assert.NoError(t, err)
z1, err := f1.SetPolyFragForParty(5, m5[1])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(1, m1[2])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(3, m3[2])
assert.NoError(t, err)
_, err = f2.SetPolyFragForParty(4, m4[2])
assert.NoError(t, err)
z2, err := f2.SetPolyFragForParty(5, m5[2])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(1, m1[3])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(2, m2[3])
assert.NoError(t, err)
_, err = f3.SetPolyFragForParty(4, m4[3])
assert.NoError(t, err)
z3, err := f3.SetPolyFragForParty(5, m5[3])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(1, m1[4])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(2, m2[4])
assert.NoError(t, err)
_, err = f4.SetPolyFragForParty(3, m3[4])
assert.NoError(t, err)
z4, err := f4.SetPolyFragForParty(5, m5[4])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(1, m1[5])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(2, m2[5])
assert.NoError(t, err)
_, err = f5.SetPolyFragForParty(3, m3[5])
assert.NoError(t, err)
z5, err := f5.SetPolyFragForParty(4, m4[5])
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(2, z2)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(3, z3)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f1.ReceiveCommitments(4, z4)
assert.NoError(t, err)
assert.NoError(t, err)
r1, err := f1.ReceiveCommitments(5, z5)
assert.NoError(t, err)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(3, z3)
assert.NoError(t, err)
_, err = f2.ReceiveCommitments(4, z4)
assert.NoError(t, err)
r2, err := f2.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f3.ReceiveCommitments(4, z4)
assert.NoError(t, err)
r3, err := f3.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f4.ReceiveCommitments(3, z3)
assert.NoError(t, err)
r4, err := f4.ReceiveCommitments(5, z5)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(1, z1)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(2, z2)
assert.NoError(t, err)
_, err = f5.ReceiveCommitments(3, z3)
assert.NoError(t, err)
r5, err := f5.ReceiveCommitments(4, z4)
assert.NoError(t, err)
_, err = f1.Recombine(2, r2)
assert.NoError(t, err)
_, err = f1.Recombine(3, r3)
assert.NoError(t, err)
_, err = f1.Recombine(4, r4)
assert.NoError(t, err)
_, err = f1.Recombine(5, r5)
assert.NoError(t, err)
_, err = f2.Recombine(1, r1)
assert.NoError(t, err)
_, err = f2.Recombine(3, r3)
assert.NoError(t, err)
_, err = f2.Recombine(4, r4)
assert.NoError(t, err)
_, err = f2.Recombine(5, r5)
assert.NoError(t, err)
_, err = f3.Recombine(1, r1)
assert.NoError(t, err)
_, err = f3.Recombine(2, r2)
assert.NoError(t, err)
_, err = f3.Recombine(4, r4)
assert.NoError(t, err)
_, err = f3.Recombine(5, r5)
assert.NoError(t, err)
_, err = f4.Recombine(1, r1)
assert.NoError(t, err)
_, err = f4.Recombine(2, r2)
assert.NoError(t, err)
_, err = f4.Recombine(3, r3)
assert.NoError(t, err)
_, err = f4.Recombine(5, r5)
assert.NoError(t, err)
_, err = f5.Recombine(1, r1)
assert.NoError(t, err)
_, err = f5.Recombine(2, r2)
assert.NoError(t, err)
_, err = f5.Recombine(3, r3)
assert.NoError(t, err)
_, err = f5.Recombine(4, r4)
assert.NoError(t, err)
}

View File

@ -0,0 +1,755 @@
package channel
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha512"
"crypto/subtle"
"encoding/binary"
"encoding/json"
"fmt"
"sort"
"github.com/pkg/errors"
"golang.org/x/crypto/hkdf"
"source.quilibrium.com/quilibrium/monorepo/nekryptology/pkg/core/curves"
"source.quilibrium.com/quilibrium/monorepo/node/keys"
"source.quilibrium.com/quilibrium/monorepo/node/protobufs"
)
const TRIPLE_RATCHET_PROTOCOL_VERSION = 1
const TRIPLE_RATCHET_PROTOCOL = 2<<8 + TRIPLE_RATCHET_PROTOCOL_VERSION
type TripleRatchetRound int
const (
TRIPLE_RATCHET_ROUND_UNINITIALIZED = TripleRatchetRound(0)
TRIPLE_RATCHET_ROUND_INITIALIZED = TripleRatchetRound(1)
TRIPLE_RATCHET_ROUND_COMMITTED = TripleRatchetRound(2)
TRIPLE_RATCHET_ROUND_REVEALED = TripleRatchetRound(3)
TRIPLE_RATCHET_ROUND_RECONSTRUCTED = TripleRatchetRound(4)
)
// Note: If an HSM with raw primitive access becomes available, the raw crypto
// mechanisms should be refactored into calls in KeyManager and implemented
// through the driver
type TripleRatchetParticipant struct {
peerKey curves.Scalar
sendingEphemeralPrivateKey curves.Scalar
receivingEphemeralKeys map[string]curves.Scalar
receivingGroupKey curves.Point
curve curves.Curve
keyManager keys.KeyManager
rootKey []byte
sendingChainKey []byte
currentHeaderKey []byte
nextHeaderKey []byte
receivingChainKey map[string][]byte
currentSendingChainLength uint32
previousSendingChainLength uint32
currentReceivingChainLength map[string]uint32
previousReceivingChainLength map[string]uint32
peerIdMap map[string]int
idPeerMap map[int]*PeerInfo
skippedKeysMap map[string]map[string]map[uint32][]byte
peerChannels map[string]*DoubleRatchetParticipant
dkgRatchet *Feldman
}
type PeerInfo struct {
PublicKey curves.Point
IdentityPublicKey curves.Point
SignedPrePublicKey curves.Point
}
// Weak-mode synchronous group modification TR this is not the asynchronous
// TR, does not ratchet group key automatically, know what your use case is
// before adopting this.
func NewTripleRatchetParticipant(
peers []*PeerInfo,
curve curves.Curve,
keyManager keys.KeyManager,
peerKey curves.Scalar,
identityKey curves.Scalar,
signedPreKey curves.Scalar,
) (
*TripleRatchetParticipant,
map[string]*protobufs.P2PChannelEnvelope,
error,
) {
participant := &TripleRatchetParticipant{}
participant.skippedKeysMap = make(map[string]map[string]map[uint32][]byte)
participant.receivingEphemeralKeys = make(map[string]curves.Scalar)
participant.receivingChainKey = make(map[string][]byte)
participant.peerChannels = make(map[string]*DoubleRatchetParticipant)
participant.keyManager = keyManager
participant.currentSendingChainLength = 0
participant.previousSendingChainLength = 0
participant.currentReceivingChainLength = make(map[string]uint32)
participant.previousReceivingChainLength = make(map[string]uint32)
peerBasis := append([]*PeerInfo{}, peers...)
peerBasis = append(peerBasis, &PeerInfo{
PublicKey: peerKey.Point().Generator().Mul(peerKey),
IdentityPublicKey: identityKey.Point().Generator().Mul(identityKey),
SignedPrePublicKey: signedPreKey.Point().Generator().Mul(signedPreKey),
})
sort.Slice(peerBasis, func(i, j int) bool {
return bytes.Compare(
peerBasis[i].PublicKey.ToAffineCompressed(),
peerBasis[j].PublicKey.ToAffineCompressed(),
) <= 0
})
initMessages := make(map[string]*protobufs.P2PChannelEnvelope)
peerIdMap := map[string]int{}
idPeerMap := map[int]*PeerInfo{}
sender := false
for i := 0; i < len(peerBasis); i++ {
peerIdMap[string(peerBasis[i].PublicKey.ToAffineCompressed())] = i + 1
idPeerMap[i+1] = peerBasis[i]
if bytes.Equal(
peerBasis[i].PublicKey.ToAffineCompressed(),
peerKey.Point().Generator().Mul(peerKey).ToAffineCompressed(),
) {
sender = true
} else {
participant.skippedKeysMap[string(
peerBasis[i].PublicKey.ToAffineCompressed(),
)] = make(map[string]map[uint32][]byte)
participant.currentReceivingChainLength[string(
peerBasis[i].PublicKey.ToAffineCompressed(),
)] = 0
participant.previousReceivingChainLength[string(
peerBasis[i].PublicKey.ToAffineCompressed(),
)] = 0
var sessionKey []byte
if sender {
sessionKey = SenderX3DH(
identityKey,
signedPreKey,
peerBasis[i].IdentityPublicKey,
peerBasis[i].SignedPrePublicKey,
96,
)
} else {
sessionKey = ReceiverX3DH(
identityKey,
signedPreKey,
peerBasis[i].IdentityPublicKey,
peerBasis[i].SignedPrePublicKey,
96,
)
}
var err error
participant.peerChannels[string(
peerBasis[i].PublicKey.ToAffineCompressed(),
)], err = NewDoubleRatchetParticipant(
sessionKey[:32],
sessionKey[32:64],
sessionKey[64:],
sender,
signedPreKey,
peerBasis[i].SignedPrePublicKey,
&curve,
keyManager,
)
if err != nil {
return nil, nil, errors.Wrap(err, "new triple ratchet participant")
}
if sender {
initMessages[string(peerBasis[i].PublicKey.ToAffineCompressed())], err =
participant.peerChannels[string(
peerBasis[i].PublicKey.ToAffineCompressed(),
)].RatchetEncrypt([]byte("init"))
if err != nil {
return nil, nil, errors.Wrap(err, "new triple ratchet participant")
}
}
}
}
feldman, err := NewFeldman(
2,
len(peers)+1,
peerIdMap[string(
peerKey.Point().Generator().Mul(peerKey).ToAffineCompressed(),
)],
curve.NewScalar().Random(rand.Reader),
curve,
curve.Point.Generator(),
)
if err != nil {
return nil, nil, errors.Wrap(err, "new triple ratchet participant")
}
participant.peerIdMap = peerIdMap
participant.idPeerMap = idPeerMap
participant.dkgRatchet = feldman
participant.curve = curve
participant.peerKey = peerKey
return participant, initMessages, nil
}
func (r *TripleRatchetParticipant) Initialize(
initMessages map[string]*protobufs.P2PChannelEnvelope,
) (map[string]*protobufs.P2PChannelEnvelope, error) {
for k, m := range initMessages {
msg, err := r.peerChannels[k].RatchetDecrypt(m)
if err != nil {
return nil, errors.Wrap(err, "initialize")
}
if string(msg) != "init" {
return nil, errors.Wrap(errors.New("invalid init message"), "initialize")
}
}
if err := r.dkgRatchet.SamplePolynomial(); err != nil {
return nil, errors.Wrap(err, "initialize")
}
result, err := r.dkgRatchet.GetPolyFrags()
if err != nil {
return nil, errors.Wrap(err, "initialize")
}
resultMap := make(map[string]*protobufs.P2PChannelEnvelope)
for k, v := range result {
if r.idPeerMap[k].PublicKey.Equal(
r.peerKey.Point().Generator().Mul(r.peerKey),
) {
continue
}
envelope, err := r.peerChannels[string(
r.idPeerMap[k].PublicKey.ToAffineCompressed(),
)].RatchetEncrypt(v)
if err != nil {
return nil, errors.Wrap(err, "initialize")
}
resultMap[string(r.idPeerMap[k].PublicKey.ToAffineCompressed())] = envelope
}
return resultMap, nil
}
func (r *TripleRatchetParticipant) ReceivePolyFrag(
peerId []byte,
frag *protobufs.P2PChannelEnvelope,
) (map[string]*protobufs.P2PChannelEnvelope, error) {
b, err := r.peerChannels[string(peerId)].RatchetDecrypt(frag)
if err != nil {
return nil, errors.Wrap(err, "receive poly frag")
}
result, err := r.dkgRatchet.SetPolyFragForParty(
r.peerIdMap[string(peerId)],
b,
)
if err != nil {
return nil, errors.Wrap(err, "receive poly frag")
}
if len(result) != 0 {
envelopes := make(map[string]*protobufs.P2PChannelEnvelope)
for k, c := range r.peerChannels {
envelope, err := c.RatchetEncrypt(result)
if err != nil {
return nil, errors.Wrap(err, "receive poly frag")
}
envelopes[k] = envelope
}
return envelopes, errors.Wrap(err, "receive poly frag")
}
return nil, nil
}
func (r *TripleRatchetParticipant) ReceiveCommitment(
peerId []byte,
zkcommit *protobufs.P2PChannelEnvelope,
) (map[string]*protobufs.P2PChannelEnvelope, error) {
b, err := r.peerChannels[string(peerId)].RatchetDecrypt(zkcommit)
if err != nil {
return nil, errors.Wrap(err, "receive commitment")
}
result, err := r.dkgRatchet.ReceiveCommitments(
r.peerIdMap[string(peerId)],
b,
)
if err != nil {
return nil, errors.Wrap(err, "receive commitment")
}
d, err := json.Marshal(result)
if err != nil {
return nil, errors.Wrap(err, "receive commitment")
}
if result != nil {
envelopes := make(map[string]*protobufs.P2PChannelEnvelope)
for k, c := range r.peerChannels {
envelope, err := c.RatchetEncrypt(d)
if err != nil {
return nil, errors.Wrap(err, "receive commitment")
}
envelopes[k] = envelope
}
return envelopes, errors.Wrap(err, "receive poly frag")
}
return nil, nil
}
func (r *TripleRatchetParticipant) Recombine(
peerId []byte,
reveal *protobufs.P2PChannelEnvelope,
) error {
b, err := r.peerChannels[string(peerId)].RatchetDecrypt(reveal)
if err != nil {
return errors.Wrap(err, "recombine")
}
rev := &FeldmanReveal{}
if err = json.Unmarshal(b, rev); err != nil {
return errors.Wrap(err, "recombine")
}
done, err := r.dkgRatchet.Recombine(
r.peerIdMap[string(peerId)],
rev,
)
if err != nil {
return errors.Wrap(err, "recombine")
}
if !done {
return nil
}
sess := sha512.Sum512_256(r.dkgRatchet.PublicKeyBytes())
hash := hkdf.New(
sha512.New,
r.dkgRatchet.PublicKeyBytes(),
sess[:],
[]byte("quilibrium-triple-ratchet"),
)
rkck := make([]byte, 96)
if _, err := hash.Read(rkck[:]); err != nil {
return errors.Wrap(err, "recombine")
}
r.rootKey = rkck[:32]
r.currentHeaderKey = rkck[32:64]
r.nextHeaderKey = rkck[64:]
r.receivingGroupKey = r.dkgRatchet.PublicKey()
r.sendingEphemeralPrivateKey = r.curve.Scalar.Random(rand.Reader)
return nil
}
func (r *TripleRatchetParticipant) RatchetEncrypt(
message []byte,
) (*protobufs.P2PChannelEnvelope, error) {
envelope := &protobufs.P2PChannelEnvelope{
ProtocolIdentifier: TRIPLE_RATCHET_PROTOCOL,
MessageHeader: &protobufs.MessageCiphertext{},
MessageBody: &protobufs.MessageCiphertext{},
}
newChainKey, messageKey, aeadKey := ratchetKeys(r.sendingChainKey)
r.sendingChainKey = newChainKey
var err error
header := r.encodeHeader()
envelope.MessageHeader, err = r.encrypt(
header,
r.currentHeaderKey,
nil,
)
if err != nil {
return nil, errors.Wrap(err, "could not encrypt header")
}
envelope.MessageBody, err = r.encrypt(
message,
messageKey,
append(append([]byte{}, aeadKey...), envelope.MessageHeader.Ciphertext...),
)
if err != nil {
return nil, errors.Wrap(err, "could not encrypt message")
}
r.currentSendingChainLength++
return envelope, nil
}
func (r *TripleRatchetParticipant) RatchetDecrypt(
envelope *protobufs.P2PChannelEnvelope,
) ([]byte, error) {
plaintext, err := r.trySkippedMessageKeys(envelope)
if err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
if plaintext != nil {
return plaintext, nil
}
header, shouldRatchet, err := r.decryptHeader(
envelope.MessageHeader,
r.currentHeaderKey,
)
if err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
senderKey,
receivingEphemeralKey,
previousReceivingChainLength,
currentReceivingChainLength,
err := r.decodeHeader(header)
if err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
if shouldRatchet {
if err := r.skipMessageKeys(
senderKey,
previousReceivingChainLength,
); err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
if err := r.ratchetReceiverEphemeralKeys(
senderKey,
receivingEphemeralKey,
); err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
}
if err := r.skipMessageKeys(
senderKey,
currentReceivingChainLength,
); err != nil {
return nil, errors.Wrap(err, "ratchet decrypt")
}
newChainKey, messageKey, aeadKey := ratchetKeys(
r.receivingChainKey[string(senderKey.ToAffineCompressed())],
)
r.receivingChainKey[string(senderKey.ToAffineCompressed())] = newChainKey
r.currentReceivingChainLength[string(senderKey.ToAffineCompressed())]++
plaintext, err = r.decrypt(
envelope.MessageBody,
messageKey,
append(
append([]byte{}, aeadKey...),
envelope.MessageHeader.Ciphertext...,
),
)
return plaintext, errors.Wrap(err, "ratchet decrypt")
}
func (r *TripleRatchetParticipant) ratchetSenderEphemeralKeys() error {
hash := hkdf.New(
sha512.New,
r.receivingGroupKey.Mul(
r.sendingEphemeralPrivateKey,
).ToAffineCompressed(),
r.rootKey,
[]byte("quilibrium-triple-ratchet"),
)
rkck2 := make([]byte, 96)
if _, err := hash.Read(rkck2[:]); err != nil {
return errors.Wrap(err, "failed ratcheting root key")
}
r.rootKey = rkck2[:32]
r.sendingChainKey = rkck2[32:64]
r.nextHeaderKey = rkck2[64:]
return nil
}
func (r *TripleRatchetParticipant) ratchetReceiverEphemeralKeys(
peerKey curves.Point,
newEphemeralKey curves.Scalar,
) error {
r.previousSendingChainLength = r.currentSendingChainLength
r.currentSendingChainLength = 0
r.currentReceivingChainLength[string(peerKey.ToAffineCompressed())] = 0
r.currentHeaderKey = r.nextHeaderKey
r.receivingEphemeralKeys[string(
peerKey.ToAffineCompressed(),
)] = newEphemeralKey
hash := hkdf.New(
sha512.New,
r.receivingGroupKey.Mul(
newEphemeralKey,
).ToAffineCompressed(),
r.rootKey,
[]byte("quilibrium-triple-ratchet"),
)
rkck := make([]byte, 96)
if _, err := hash.Read(rkck[:]); err != nil {
return errors.Wrap(err, "failed ratcheting root key")
}
r.rootKey = rkck[:32]
r.receivingChainKey[string(peerKey.ToAffineCompressed())] = rkck[32:64]
r.nextHeaderKey = rkck[64:]
r.sendingEphemeralPrivateKey = r.curve.NewScalar().Random(rand.Reader)
return nil
}
func (r *TripleRatchetParticipant) trySkippedMessageKeys(
envelope *protobufs.P2PChannelEnvelope,
) ([]byte, error) {
for receivingHeaderKey, skippedKeys := range r.skippedKeysMap {
header, _, err := r.decryptHeader(
envelope.MessageHeader,
[]byte(receivingHeaderKey),
)
if err == nil {
peerKey, _, _, current, err := r.decodeHeader(header)
if err != nil {
return nil, errors.Wrap(err, "try skipped message keys")
}
messageKey := skippedKeys[string(
peerKey.ToAffineCompressed(),
)][current][:32]
aeadKey := skippedKeys[string(
peerKey.ToAffineCompressed(),
)][current][32:]
plaintext, err := r.decrypt(
envelope.MessageBody,
messageKey,
append(
append([]byte{}, aeadKey...),
envelope.MessageHeader.Ciphertext[:]...,
),
)
if err != nil {
return nil, errors.Wrap(err, "try skipped message keys")
}
delete(r.skippedKeysMap[string(
peerKey.ToAffineCompressed(),
)][receivingHeaderKey], current)
if len(r.skippedKeysMap[string(
peerKey.ToAffineCompressed(),
)][receivingHeaderKey]) == 0 {
delete(r.skippedKeysMap[string(
peerKey.ToAffineCompressed(),
)], receivingHeaderKey)
}
return plaintext, nil
}
}
return nil, nil
}
func (r *TripleRatchetParticipant) skipMessageKeys(
senderKey curves.Point,
until uint32,
) error {
if r.currentReceivingChainLength[string(
senderKey.ToAffineCompressed(),
)]+100 < until {
return errors.Wrap(errors.New("skip limit exceeded"), "skip message keys")
}
if r.receivingChainKey != nil {
for r.currentReceivingChainLength[string(
senderKey.ToAffineCompressed(),
)] < until {
newChainKey, messageKey, aeadKey := ratchetKeys(
r.receivingChainKey[string(
senderKey.ToAffineCompressed(),
)],
)
skippedKeys := r.skippedKeysMap[string(
senderKey.ToAffineCompressed(),
)][string(r.currentHeaderKey)]
if skippedKeys == nil {
r.skippedKeysMap[string(
senderKey.ToAffineCompressed(),
)][string(r.currentHeaderKey)] =
make(map[uint32][]byte)
}
skippedKeys[r.currentReceivingChainLength[string(
senderKey.ToAffineCompressed(),
)]] = append(
append([]byte{}, messageKey...),
aeadKey...,
)
r.receivingChainKey[string(
senderKey.ToAffineCompressed(),
)] = newChainKey
r.currentReceivingChainLength[string(
senderKey.ToAffineCompressed(),
)]++
}
}
return nil
}
func (r *TripleRatchetParticipant) encodeHeader() []byte {
header := []byte{}
header = append(
header,
r.peerKey.Point().Generator().Mul(r.peerKey).ToAffineCompressed()...,
)
header = append(
header,
r.sendingEphemeralPrivateKey.Bytes()...,
)
header = binary.BigEndian.AppendUint32(header, r.previousSendingChainLength)
header = binary.BigEndian.AppendUint32(header, r.currentSendingChainLength)
return header
}
func (r *TripleRatchetParticipant) decryptHeader(
ciphertext *protobufs.MessageCiphertext,
receivingHeaderKey []byte,
) ([]byte, bool, error) {
header, err := r.decrypt(
ciphertext,
receivingHeaderKey,
nil,
)
if err != nil && subtle.ConstantTimeCompare(
r.currentHeaderKey,
receivingHeaderKey,
) == 1 {
if header, err = r.decrypt(
ciphertext,
r.nextHeaderKey,
nil,
); err != nil {
return nil, false, errors.Wrap(err, "could not decrypt header")
}
fmt.Println("should ratchet")
return header, true, nil
}
return header, false, errors.Wrap(err, "could not decrypt header")
}
func (r *TripleRatchetParticipant) decodeHeader(
header []byte,
) (curves.Point, curves.Scalar, uint32, uint32, error) {
if len(header) < 9 {
return nil, nil, 0, 0, errors.Wrap(
errors.New("malformed header"),
"decode header",
)
}
currentReceivingChainLength := binary.BigEndian.Uint32(header[len(header)-4:])
previousReceivingChainLength := binary.BigEndian.Uint32(
header[len(header)-8 : len(header)-4],
)
sender := header[:len(r.curve.Point.ToAffineCompressed())]
senderKey, err := r.curve.Point.FromAffineCompressed(sender)
if err != nil {
return nil, nil, 0, 0, errors.Wrap(err, "decode header")
}
receivingEphemeralKeyBytes := header[len(
r.curve.Point.ToAffineCompressed(),
) : len(header)-8]
receivingEphemeralKey, err := r.curve.Scalar.Clone().SetBytes(
receivingEphemeralKeyBytes,
)
return senderKey,
receivingEphemeralKey,
previousReceivingChainLength,
currentReceivingChainLength,
errors.Wrap(err, "decode header")
}
func (r *TripleRatchetParticipant) encrypt(
plaintext []byte,
key []byte,
associatedData []byte,
) (*protobufs.MessageCiphertext, error) {
iv := [12]byte{}
rand.Read(iv[:])
aesCipher, err := aes.NewCipher(key)
if err != nil {
return nil, errors.Wrap(err, "encrypt")
}
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, errors.Wrap(err, "encrypt")
}
ciphertext := &protobufs.MessageCiphertext{}
if associatedData == nil {
associatedData = make([]byte, 32)
if _, err := rand.Read(associatedData); err != nil {
return nil, errors.Wrap(err, "encrypt")
}
ciphertext.AssociatedData = associatedData
}
ciphertext.Ciphertext = gcm.Seal(nil, iv[:], plaintext, associatedData)
ciphertext.InitializationVector = iv[:]
return ciphertext, nil
}
func (r *TripleRatchetParticipant) decrypt(
ciphertext *protobufs.MessageCiphertext,
key []byte,
associatedData []byte,
) ([]byte, error) {
if associatedData == nil {
associatedData = ciphertext.AssociatedData
}
aesCipher, err := aes.NewCipher(key)
if err != nil {
return nil, errors.Wrap(err, "decrypt")
}
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, errors.Wrap(err, "decrypt")
}
plaintext, err := gcm.Open(
nil,
ciphertext.InitializationVector,
ciphertext.Ciphertext,
associatedData,
)
return plaintext, errors.Wrap(err, "decrypt")
}

View File

@ -34,7 +34,10 @@ require (
) )
require ( require (
github.com/deiu/gon3 v0.0.0-20230411081920-f0f8f879f597 // indirect
github.com/deiu/rdf2go v0.0.0-20240619132609-81222e324bb9 // indirect
github.com/hashicorp/golang-lru/arc/v2 v2.0.7 // indirect github.com/hashicorp/golang-lru/arc/v2 v2.0.7 // indirect
github.com/linkeddata/gojsonld v0.0.0-20170418210642-4f5db6791326 // indirect
github.com/pion/datachannel v1.5.6 // indirect github.com/pion/datachannel v1.5.6 // indirect
github.com/pion/dtls/v2 v2.2.11 // indirect github.com/pion/dtls/v2 v2.2.11 // indirect
github.com/pion/ice/v2 v2.3.24 // indirect github.com/pion/ice/v2 v2.3.24 // indirect
@ -51,6 +54,7 @@ require (
github.com/pion/transport/v2 v2.2.5 // indirect github.com/pion/transport/v2 v2.2.5 // indirect
github.com/pion/turn/v2 v2.1.6 // indirect github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pion/webrtc/v3 v3.2.40 // indirect github.com/pion/webrtc/v3 v3.2.40 // indirect
github.com/rychipman/easylex v0.0.0-20160129204217-49ee7767142f // indirect
go.opentelemetry.io/otel v1.14.0 // indirect go.opentelemetry.io/otel v1.14.0 // indirect
go.opentelemetry.io/otel/trace v1.14.0 // indirect go.opentelemetry.io/otel/trace v1.14.0 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect

View File

@ -90,6 +90,10 @@ github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPc
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218=
github.com/deiu/gon3 v0.0.0-20230411081920-f0f8f879f597 h1:xKCSqM+c9FjQIr0Qacn9m7x0kv/opDWGr/nvCowFCok=
github.com/deiu/gon3 v0.0.0-20230411081920-f0f8f879f597/go.mod h1:r8Pv5x6dxChq4mb1ZqzTyK3y9w8wDzWt55XAJpfSq34=
github.com/deiu/rdf2go v0.0.0-20240619132609-81222e324bb9 h1:xs255gi9FPRuCW+Ud8lQOBXBGHqM8cqqmoRfGokK3f0=
github.com/deiu/rdf2go v0.0.0-20240619132609-81222e324bb9/go.mod h1:d+9YsU6N5OuirjLEOp23T2/+S7OLByerfuv1f89iy90=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
@ -291,6 +295,8 @@ github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQsc
github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU=
github.com/libp2p/go-yamux/v4 v4.0.1 h1:FfDR4S1wj6Bw2Pqbc8Uz7pCxeRBPbwsBbEdfwiCypkQ= github.com/libp2p/go-yamux/v4 v4.0.1 h1:FfDR4S1wj6Bw2Pqbc8Uz7pCxeRBPbwsBbEdfwiCypkQ=
github.com/libp2p/go-yamux/v4 v4.0.1/go.mod h1:NWjl8ZTLOGlozrXSOZ/HlfG++39iKNnM5wwmtQP1YB4= github.com/libp2p/go-yamux/v4 v4.0.1/go.mod h1:NWjl8ZTLOGlozrXSOZ/HlfG++39iKNnM5wwmtQP1YB4=
github.com/linkeddata/gojsonld v0.0.0-20170418210642-4f5db6791326 h1:YP3lfXXYiQV5MKeUqVnxRP5uuMQTLPx+PGYm1UBoU98=
github.com/linkeddata/gojsonld v0.0.0-20170418210642-4f5db6791326/go.mod h1:nfqkuSNlsk1bvti/oa7TThx4KmRMBmSxf3okHI9wp3E=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
@ -464,6 +470,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/rychipman/easylex v0.0.0-20160129204217-49ee7767142f h1:L2/fBPABieQnQzfV40k2Zw7IcvZbt0CN5TgwUl8zDCs=
github.com/rychipman/easylex v0.0.0-20160129204217-49ee7767142f/go.mod h1:MZ2GRTcqmve6EoSbErWgCR+Ash4p8Gc5esHe8MDErss=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=

451
node/schema/rdf.go Normal file
View File

@ -0,0 +1,451 @@
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/deiu/rdf2go"
"github.com/pkg/errors"
)
type RDFParser interface {
Validate(document string) (bool, error)
}
type TurtleRDFParser struct {
}
type Field struct {
Name string
Type string
Size uint32
Comment string
Annotation string
RdfType string
Order int
ClassUrl rdf2go.Term
}
const RdfNS = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
const RdfsNS = "http://www.w3.org/2000/01/rdf-schema#"
const SchemaRepositoryNS = "https://types.quilibrium.com/schema-repository/"
const QCLNS = "https://types.quilibrium.com/qcl/"
const Prefix = "<%s>"
const TupleString = "%s%s"
const NTupleString = "<%s%s>"
var rdfTypeN = fmt.Sprintf(NTupleString, RdfNS, "type")
var rdfsClassN = fmt.Sprintf(NTupleString, RdfsNS, "Class")
var rdfsPropertyN = fmt.Sprintf(NTupleString, RdfsNS, "Property")
var rdfsDomainN = fmt.Sprintf(NTupleString, RdfsNS, "domain")
var rdfsRangeN = fmt.Sprintf(NTupleString, RdfsNS, "range")
var rdfsCommentN = fmt.Sprintf(NTupleString, RdfsNS, "comment")
var rdfType = fmt.Sprintf(TupleString, RdfNS, "type")
var rdfsClass = fmt.Sprintf(TupleString, RdfsNS, "Class")
var rdfsProperty = fmt.Sprintf(TupleString, RdfsNS, "Property")
var rdfsDomain = fmt.Sprintf(TupleString, RdfsNS, "domain")
var rdfsRange = fmt.Sprintf(TupleString, RdfsNS, "range")
var qclSize = fmt.Sprintf(TupleString, QCLNS, "size")
var qclOrder = fmt.Sprintf(TupleString, QCLNS, "order")
var rdfsComment = fmt.Sprintf(TupleString, RdfsNS, "comment")
var qclRdfTypeMap = map[string]string{
"Uint": "uint%d",
"Int": "int%d",
"ByteArray": "[%d]byte",
"Bool": "bool",
"Float": "float%d",
"String": "string",
"Struct": "struct",
}
func (t *TurtleRDFParser) Validate(document string) (bool, error) {
g := rdf2go.NewGraph("https://types.quilibrium.com/schema-repository/")
reader := strings.NewReader(document)
err := g.Parse(reader, "text/turtle")
if err != nil {
return false, errors.Wrap(err, "validate")
}
return true, nil
}
func (t *TurtleRDFParser) GenerateQCL(document string) (string, error) {
g := rdf2go.NewGraph("https://types.quilibrium.com/schema-repository/")
reader := strings.NewReader(document)
err := g.Parse(reader, "text/turtle")
if err != nil {
return "", errors.Wrap(err, "validate")
}
prefixMap := make(map[string]string)
for _, line := range strings.Split(document, "\n") {
parts := strings.Split(line, " ")
switch parts[0] {
case "PREFIX":
if len(parts) != 3 {
return "", errors.Wrap(err, "invalid PREFIX line")
}
prefixMap[strings.Trim(parts[2], "<>")] = parts[1]
}
}
iter := g.IterTriples()
classes := []string{}
classTerms := []rdf2go.Term{}
classUrls := []string{}
fields := make(map[string]map[string]*Field)
for a := range iter {
if a.Predicate.String() == rdfTypeN &&
a.Object.String() == rdfsClassN {
subj := a.Subject.RawValue()
parts := strings.Split(subj, "#")
className := parts[len(parts)-1]
parts = strings.Split(className, "/")
className = parts[len(parts)-1]
classUrl := subj[:len(subj)-len(className)]
classes = append(classes, className)
classUrls = append(classUrls, classUrl)
classTerms = append(classTerms, a.Subject)
}
}
for i, c := range classTerms {
for _, prop := range g.All(nil, rdf2go.NewResource(rdfsRange), c) {
subj := prop.Subject.RawValue()
parts := strings.Split(subj, "#")
className := parts[len(parts)-1]
parts = strings.Split(className, "/")
className = parts[len(parts)-1]
classUrl := subj[:len(subj)-len(className)]
if _, ok := fields[classes[i]]; !ok {
fields[classes[i]] = make(map[string]*Field)
}
fields[classes[i]][className] = &Field{
Name: className,
ClassUrl: prop.Subject,
Annotation: prefixMap[classUrl] + className,
Order: -1,
}
}
}
for _, class := range fields {
for fieldName, field := range class {
// scan the types
for _, prop := range g.All(field.ClassUrl, rdf2go.NewResource(
rdfsDomain,
), nil) {
obj := prop.Object.RawValue()
parts := strings.Split(obj, "#")
className := parts[len(parts)-1]
parts = strings.Split(className, "/")
className = parts[len(parts)-1]
classUrl := obj[:len(obj)-len(className)]
switch classUrl {
case QCLNS:
field.Type = qclRdfTypeMap[className]
for _, sprop := range g.All(field.ClassUrl, rdf2go.NewResource(
qclSize,
), nil) {
sobj := sprop.Object.RawValue()
parts := strings.Split(sobj, "#")
size := parts[len(parts)-1]
parts = strings.Split(size, "/")
size = parts[len(parts)-1]
s, err := strconv.Atoi(size)
fieldSize := s
if className != "String" && className != "ByteArray" && className != "Struct" {
fieldSize *= 8
}
if err != nil || s < 1 {
return "", errors.Wrap(
fmt.Errorf(
"invalid size for %s: %s",
fieldName,
size,
),
"generate qcl",
)
}
if strings.Contains(field.Type, "%") {
field.Type = fmt.Sprintf(field.Type, fieldSize)
}
field.RdfType = className
field.Size = uint32(s)
}
if strings.Contains(field.Type, "%d") {
return "", errors.Wrap(
fmt.Errorf(
"size unspecified for %s, add a qcl:size predicate",
fieldName,
),
"generate qcl",
)
}
case RdfsNS:
if className != "Literal" {
return "", errors.Wrap(
fmt.Errorf(
"invalid property type for %s: %s",
fieldName,
className,
),
"generate qcl",
)
}
field.Type = className
default:
field.Type = "hypergraph.Extrinsic"
field.Annotation += ",extrinsic=" + prefixMap[classUrl] + className
field.Size = 32
field.RdfType = "Struct"
}
break
}
for _, sprop := range g.All(field.ClassUrl, rdf2go.NewResource(
qclOrder,
), nil) {
sobj := sprop.Object.RawValue()
parts := strings.Split(sobj, "#")
order := parts[len(parts)-1]
parts = strings.Split(order, "/")
order = parts[len(parts)-1]
o, err := strconv.Atoi(order)
fieldOrder := o
if err != nil || o < 0 {
return "", errors.Wrap(
fmt.Errorf(
"invalid order for %s: %s",
fieldName,
order,
),
"generate qcl",
)
}
field.Order = fieldOrder
}
if field.Order < 0 {
return "", errors.Wrap(
fmt.Errorf(
"field order unspecified for %s, add a qcl:order predicate",
fieldName,
),
"generate qcl",
)
}
for _, prop := range g.All(field.ClassUrl, rdf2go.NewResource(
rdfsComment,
), nil) {
field.Comment = prop.Object.String()
}
}
}
output := "package main\n\n"
sort.Slice(classes, func(i, j int) bool {
return strings.Compare(classes[i], classes[j]) < 0
})
for _, class := range classes {
output += fmt.Sprintf("type %s struct {\n", class)
sortedFields := []*Field{}
for _, field := range fields[class] {
sortedFields = append(sortedFields, field)
}
sort.Slice(sortedFields, func(i, j int) bool {
return sortedFields[i].Order < sortedFields[j].Order
})
for _, field := range sortedFields {
if field.Comment != "" {
output += fmt.Sprintf(" // %s\n", field.Comment)
}
output += fmt.Sprintf(
" %s %s `rdf:\"%s\"`\n",
field.Name,
field.Type,
field.Annotation,
)
}
output += "}\n\n"
}
for _, class := range classes {
totalSize := uint32(0)
for _, field := range fields[class] {
totalSize += field.Size
}
output += fmt.Sprintf(
"func Unmarshal%s(payload [%d]byte) %s {\n result := %s{}\n",
class,
totalSize,
class,
class,
)
s := uint32(0)
sortedFields := []*Field{}
for _, field := range fields[class] {
sortedFields = append(sortedFields, field)
}
sort.Slice(sortedFields, func(i, j int) bool {
return sortedFields[i].Order < sortedFields[j].Order
})
for _, field := range sortedFields {
sizedType := ""
switch field.RdfType {
case "Uint":
sizedType = fmt.Sprintf(
"binary.GetUint(payload[%d:%d])",
s,
s+field.Size,
)
s += field.Size
case "Int":
sizedType = fmt.Sprintf(
"int%d(binary.GetUint(payload[%d:%d]))",
field.Size,
s,
s+field.Size,
)
s += field.Size
case "ByteArray":
sizedType = fmt.Sprintf(
"payload[%d:%d]",
s,
s+field.Size,
)
s += field.Size
case "Bool":
sizedType = "bool"
s++
case "Float":
sizedType = fmt.Sprintf(
"payload[%d:%d]",
s,
s+field.Size,
)
s += field.Size
case "String":
sizedType = fmt.Sprintf(
"string(payload[%d:%d])",
s,
s+field.Size,
)
s += field.Size
case "Struct":
sizedType = fmt.Sprintf(
"hypergraph.Extrinsic{}\n result.%s.Ref = payload[%d:%d]",
field.Name,
s,
s+field.Size,
)
s += field.Size
}
output += fmt.Sprintf(
" result.%s = %s\n",
field.Name,
sizedType,
)
}
output += " return result\n}\n\n"
}
for _, class := range classes {
totalSize := uint32(0)
for _, field := range fields[class] {
totalSize += field.Size
}
output += fmt.Sprintf(
"func Marshal%s(obj %s) [%d]byte {\n",
class,
class,
totalSize,
)
s := uint32(0)
sortedFields := []*Field{}
for _, field := range fields[class] {
sortedFields = append(sortedFields, field)
}
sort.Slice(sortedFields, func(i, j int) bool {
return sortedFields[i].Order < sortedFields[j].Order
})
output += fmt.Sprintf(" buf := make([]byte, %d)\n", totalSize)
for _, field := range sortedFields {
sizedType := ""
switch field.RdfType {
case "Uint":
sizedType = fmt.Sprintf(
"binary.PutUint(buf, %d, obj.%s)",
s,
field.Name,
)
s += field.Size
case "Int":
sizedType = fmt.Sprintf(
"binary.PutInt(buf, %d, obj.%s)",
s,
field.Name,
)
s += field.Size
case "ByteArray":
sizedType = fmt.Sprintf(
"copy(buf[%d:%d], obj.%s)",
s,
s+field.Size,
field.Name,
)
s += field.Size
case "Bool":
sizedType = fmt.Sprintf(
"if obj.%s { buf[%d] = 0xff } else { buf[%d] = 0x00 }",
field.Name,
s,
s,
)
s++
case "Float":
sizedType = fmt.Sprintf(
"copy(buf[%d:%d], obj.%s)",
s,
s+field.Size,
field.Name,
)
s += field.Size
case "String":
sizedType = fmt.Sprintf(
"copy(buf[%d:%d], []byte(obj.%s))",
s,
s+field.Size,
field.Name,
)
s += field.Size
case "Struct":
sizedType = fmt.Sprintf(
"copy(buf[%d:%d], obj.%s.Ref)",
s,
s+field.Size,
field.Name,
)
s += field.Size
}
output += fmt.Sprintf(
" %s\n",
sizedType,
)
}
output += " return buf\n}\n\n"
}
return output, nil
}