0g-chain/x/swap/keeper/invariants.go
Ruaridh ffef832d45
Upgrade to sdk v0.44.5 and add IBC ()
- Upgrade cosmos-sdk to v0.44.5 from v0.39.2
- Add Legacy Tx Endpoint for backwards compatibility
- Add IBC v1.2.3 Support

Co-authored-by: DracoLi <draco@dracoli.com>
Co-authored-by: drklee3 <derrick@dlee.dev>
Co-authored-by: denalimarsh <denalimarsh@gmail.com>
Co-authored-by: Draco Li <draco@kava.io>
Co-authored-by: Nick DeLuca <nickdeluca08@gmail.com>
Co-authored-by: Kevin Davis <karzak@users.noreply.github.com>
Co-authored-by: Denali Marsh <denali@kava.io>
2022-01-07 17:39:27 -07:00

139 lines
3.8 KiB
Go

package keeper
import (
"github.com/kava-labs/kava/x/swap/types"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// RegisterInvariants registers every swap module invariant with the given
// registry under the module's name. Routes are registered in a fixed order:
// pool-records, share-records, pool-reserves, pool-shares.
func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
	routes := []struct {
		route     string
		invariant sdk.Invariant
	}{
		{route: "pool-records", invariant: PoolRecordsInvariant(k)},
		{route: "share-records", invariant: ShareRecordsInvariant(k)},
		{route: "pool-reserves", invariant: PoolReservesInvariant(k)},
		{route: "pool-shares", invariant: PoolSharesInvariant(k)},
	}
	for _, r := range routes {
		ir.RegisterRoute(types.ModuleName, r.route, r.invariant)
	}
}
// AllInvariants returns a single invariant that runs every swap module
// invariant in order (pool records, share records, pool reserves, pool
// shares). It returns the result of the first invariant that reports broken.
// If none do, it returns the last invariant's result.
func AllInvariants(k Keeper) sdk.Invariant {
	return func(ctx sdk.Context) (string, bool) {
		// Each invariant is constructed anew on every call, so no closure
		// state carries over from a previous run.
		constructors := []func(Keeper) sdk.Invariant{
			PoolRecordsInvariant,
			ShareRecordsInvariant,
			PoolReservesInvariant,
			PoolSharesInvariant,
		}

		var (
			res  string
			stop bool
		)
		for _, newInvariant := range constructors {
			res, stop = newInvariant(k)(ctx)
			if stop {
				break
			}
		}
		return res, stop
	}
}
// PoolRecordsInvariant iterates all pool records and asserts that they are valid.
//
// The invariant is broken if any stored pool record fails PoolRecord.Validate.
// The returned message is the same constant string on every run; only the
// broken flag varies.
func PoolRecordsInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "validate pool records broken", "pool record invalid")

	return func(ctx sdk.Context) (string, bool) {
		// Declared per invocation: a flag captured outside the closure would
		// stay true forever after a single failing run and be shared between runs.
		broken := false
		k.IteratePools(ctx, func(record types.PoolRecord) bool {
			if err := record.Validate(); err != nil {
				broken = true
				// A true return is assumed to end IteratePools early — confirm
				// against the keeper's iterator contract.
				return true
			}
			return false
		})
		return message, broken
	}
}
// ShareRecordsInvariant iterates all share records and asserts that they are valid.
//
// The invariant is broken if any stored depositor share record fails
// ShareRecord.Validate. The returned message is the same constant string on
// every run; only the broken flag varies.
func ShareRecordsInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "validate share records broken", "share record invalid")

	return func(ctx sdk.Context) (string, bool) {
		// Declared per invocation: a flag captured outside the closure would
		// stay true forever after a single failing run and be shared between runs.
		broken := false
		k.IterateDepositorShares(ctx, func(record types.ShareRecord) bool {
			if err := record.Validate(); err != nil {
				broken = true
				// A true return is assumed to end iteration early — confirm
				// against the keeper's iterator contract.
				return true
			}
			return false
		})
		return message, broken
	}
}
// PoolReservesInvariant iterates all pools and ensures the total reserves matches the module account coins.
//
// The reserves of every pool record are summed. The sum must equal the
// complete balance held by the swap module account. Any difference in the set
// of denoms, or in any amount, breaks the invariant.
func PoolReservesInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "pool reserves broken", "pool reserves do not match module account")

	return func(ctx sdk.Context) (string, bool) {
		balance := k.bankKeeper.GetAllBalances(ctx, k.GetSwapModuleAccount(ctx).GetAddress())

		reserves := sdk.Coins{}
		k.IteratePools(ctx, func(record types.PoolRecord) bool {
			for _, coin := range record.Reserves() {
				reserves = reserves.Add(coin)
			}
			return false
		})

		// Compare denom-by-denom instead of calling sdk.Coins.IsEqual. In the
		// SDK version this module targets (v0.44), Coins.IsEqual delegates to
		// Coin.IsEqual, which panics when two equal-length coin sets hold
		// different denoms. The invariant should report that mismatch as
		// broken, not crash. Both sets are sorted so the pairwise walk is well-defined.
		reserves = reserves.Sort()
		balance = balance.Sort()
		broken := len(reserves) != len(balance)
		for i := 0; !broken && i < len(reserves); i++ {
			broken = reserves[i].Denom != balance[i].Denom || !reserves[i].Amount.Equal(balance[i].Amount)
		}
		return message, broken
	}
}
// poolShares is the per-pool tally used by PoolSharesInvariant.
type poolShares struct {
	// totalShares is the TotalShares of the pool record, or zero when only
	// share records exist for the pool. totalSharesOwned is the running sum of
	// SharesOwned across the pool's depositor share records.
	totalShares, totalSharesOwned sdk.Int
}
// PoolSharesInvariant iterates all pools and shares and ensures the total pool shares match the sum of depositor shares.
//
// For every pool ID seen in either a pool record or a share record, the pool's
// TotalShares must equal the sum of SharesOwned across its share records.
// A pool with no share records is compared against zero owned shares. A share
// record whose pool record is missing is compared against zero total shares.
// Either orphan with a non-zero amount therefore breaks the invariant.
func PoolSharesInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "pool shares broken", "pool shares do not match depositor shares")

	return func(ctx sdk.Context) (string, bool) {
		// Built per invocation, keyed by pool ID, so no state leaks between runs.
		totalShares := make(map[string]poolShares)

		k.IteratePools(ctx, func(pr types.PoolRecord) bool {
			totalShares[pr.PoolID] = poolShares{
				totalShares:      pr.TotalShares,
				totalSharesOwned: sdk.ZeroInt(),
			}
			return false
		})

		k.IterateDepositorShares(ctx, func(sr types.ShareRecord) bool {
			if shares, found := totalShares[sr.PoolID]; found {
				shares.totalSharesOwned = shares.totalSharesOwned.Add(sr.SharesOwned)
				totalShares[sr.PoolID] = shares
			} else {
				// Share record without a pool record: its pool total counts as zero.
				totalShares[sr.PoolID] = poolShares{
					totalShares:      sdk.ZeroInt(),
					totalSharesOwned: sr.SharesOwned,
				}
			}
			return false
		})

		// Map iteration order is random, but only the existence of a mismatch
		// matters, so the result is deterministic.
		for _, ps := range totalShares {
			if !ps.totalShares.Equal(ps.totalSharesOwned) {
				return message, true
			}
		}
		return message, false
	}
}