Add Invariants for Swap Module (#979)

* add swap module invariants

* typo

* update alias file for invariants

* typo in test name

* fix typo - method iterates share record, not pools
This commit is contained in:
Nick DeLuca 2021-08-05 20:43:55 -05:00 committed by GitHub
parent 56463eca14
commit b86cfc9f14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 423 additions and 20 deletions

View File

@ -24,6 +24,7 @@ const (
EventTypeSwapWithdraw = types.EventTypeSwapWithdraw EventTypeSwapWithdraw = types.EventTypeSwapWithdraw
ModuleAccountName = types.ModuleAccountName ModuleAccountName = types.ModuleAccountName
ModuleName = types.ModuleName ModuleName = types.ModuleName
PoolIDSep = types.PoolIDSep
QuerierRoute = types.QuerierRoute QuerierRoute = types.QuerierRoute
QueryGetDeposits = types.QueryGetDeposits QueryGetDeposits = types.QueryGetDeposits
QueryGetParams = types.QueryGetParams QueryGetParams = types.QueryGetParams
@ -35,8 +36,14 @@ const (
var ( var (
// function aliases // function aliases
AllInvariants = keeper.AllInvariants
NewKeeper = keeper.NewKeeper NewKeeper = keeper.NewKeeper
NewQuerier = keeper.NewQuerier NewQuerier = keeper.NewQuerier
PoolRecordsInvariant = keeper.PoolRecordsInvariant
PoolReservesInvariant = keeper.PoolReservesInvariant
PoolSharesInvariant = keeper.PoolSharesInvariant
RegisterInvariants = keeper.RegisterInvariants
ShareRecordsInvariant = keeper.ShareRecordsInvariant
DefaultGenesisState = types.DefaultGenesisState DefaultGenesisState = types.DefaultGenesisState
DefaultParams = types.DefaultParams DefaultParams = types.DefaultParams
DepositorPoolSharesKey = types.DepositorPoolSharesKey DepositorPoolSharesKey = types.DepositorPoolSharesKey

View File

@ -1,19 +0,0 @@
package keeper_test
import (
"testing"
//"github.com/kava-labs/kava/x/swap"
"github.com/kava-labs/kava/x/swap/testutil"
//"github.com/kava-labs/kava/x/swap/types"
"github.com/stretchr/testify/suite"
//sdk "github.com/cosmos/cosmos-sdk/types"
)
type invariantTestSuite struct {
testutil.Suite
}
func TestGenesisTestSuite(t *testing.T) {
suite.Run(t, new(invariantTestSuite))
}

138
x/swap/keeper/invariants.go Normal file
View File

@ -0,0 +1,138 @@
package keeper
import (
"github.com/kava-labs/kava/x/swap/types"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// RegisterInvariants registers the swap module invariants
func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
ir.RegisterRoute(types.ModuleName, "pool-records", PoolRecordsInvariant(k))
ir.RegisterRoute(types.ModuleName, "share-records", ShareRecordsInvariant(k))
ir.RegisterRoute(types.ModuleName, "pool-reserves", PoolReservesInvariant(k))
ir.RegisterRoute(types.ModuleName, "pool-shares", PoolSharesInvariant(k))
}
// AllInvariants runs all invariants of the swap module
func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) (string, bool) {
if res, stop := PoolRecordsInvariant(k)(ctx); stop {
return res, stop
}
if res, stop := ShareRecordsInvariant(k)(ctx); stop {
return res, stop
}
if res, stop := PoolReservesInvariant(k)(ctx); stop {
return res, stop
}
res, stop := PoolSharesInvariant(k)(ctx)
return res, stop
}
}
// PoolRecordsInvariant iterates all pool records and asserts that they are valid
func PoolRecordsInvariant(k Keeper) sdk.Invariant {
broken := false
message := sdk.FormatInvariant(types.ModuleName, "validate pool records broken", "pool record invalid")
return func(ctx sdk.Context) (string, bool) {
k.IteratePools(ctx, func(record types.PoolRecord) bool {
if err := record.Validate(); err != nil {
broken = true
return true
}
return false
})
return message, broken
}
}
// ShareRecordsInvariant iterates all share records and asserts that they are valid
func ShareRecordsInvariant(k Keeper) sdk.Invariant {
broken := false
message := sdk.FormatInvariant(types.ModuleName, "validate share records broken", "share record invalid")
return func(ctx sdk.Context) (string, bool) {
k.IterateDepositorShares(ctx, func(record types.ShareRecord) bool {
if err := record.Validate(); err != nil {
broken = true
return true
}
return false
})
return message, broken
}
}
// PoolReservesInvariant iterates all pools and ensures the total reserves matches the module account coins
func PoolReservesInvariant(k Keeper) sdk.Invariant {
message := sdk.FormatInvariant(types.ModuleName, "pool reserves broken", "pool reserves do not match module account")
return func(ctx sdk.Context) (string, bool) {
mAcc := k.supplyKeeper.GetModuleAccount(ctx, types.ModuleName)
reserves := sdk.Coins{}
k.IteratePools(ctx, func(record types.PoolRecord) bool {
for _, coin := range record.Reserves() {
reserves = reserves.Add(coin)
}
return false
})
broken := !reserves.IsEqual(mAcc.GetCoins())
return message, broken
}
}
type poolShares struct {
totalShares sdk.Int
totalSharesOwned sdk.Int
}
// PoolSharesInvariant iterates all pools and shares and ensures the total pool shares match the sum of depositor shares
func PoolSharesInvariant(k Keeper) sdk.Invariant {
broken := false
message := sdk.FormatInvariant(types.ModuleName, "pool shares broken", "pool shares do not match depositor shares")
return func(ctx sdk.Context) (string, bool) {
totalShares := make(map[string]poolShares)
k.IteratePools(ctx, func(pr types.PoolRecord) bool {
totalShares[pr.PoolID] = poolShares{
totalShares: pr.TotalShares,
totalSharesOwned: sdk.ZeroInt(),
}
return false
})
k.IterateDepositorShares(ctx, func(sr types.ShareRecord) bool {
if shares, found := totalShares[sr.PoolID]; found {
shares.totalSharesOwned = shares.totalSharesOwned.Add(sr.SharesOwned)
totalShares[sr.PoolID] = shares
} else {
totalShares[sr.PoolID] = poolShares{
totalShares: sdk.ZeroInt(),
totalSharesOwned: sr.SharesOwned,
}
}
return false
})
for _, ps := range totalShares {
if !ps.totalShares.Equal(ps.totalSharesOwned) {
broken = true
break
}
}
return message, broken
}
}

View File

@ -0,0 +1,235 @@
package keeper_test
import (
"testing"
"github.com/kava-labs/kava/x/swap/keeper"
"github.com/kava-labs/kava/x/swap/testutil"
"github.com/kava-labs/kava/x/swap/types"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/suite"
)
type invariantTestSuite struct {
testutil.Suite
invariants map[string]map[string]sdk.Invariant
}
func (suite *invariantTestSuite) SetupTest() {
suite.Suite.SetupTest()
suite.invariants = make(map[string]map[string]sdk.Invariant)
keeper.RegisterInvariants(suite, suite.Keeper)
}
func (suite *invariantTestSuite) SetupValidState() {
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
sdk.NewInt(3e6),
))
suite.AddCoinsToModule(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
)
suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord(
sdk.AccAddress("depositor 1"),
types.PoolID("ukava", "usdx"),
sdk.NewInt(2e6),
))
suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord(
sdk.AccAddress("depositor 2"),
types.PoolID("ukava", "usdx"),
sdk.NewInt(1e6),
))
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("hard", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(2e6)),
),
sdk.NewInt(1e6),
))
suite.AddCoinsToModule(
sdk.NewCoins(
sdk.NewCoin("hard", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(2e6)),
),
)
suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord(
sdk.AccAddress("depositor 1"),
types.PoolID("hard", "usdx"),
sdk.NewInt(1e6),
))
}
func (suite *invariantTestSuite) RegisterRoute(moduleName string, route string, invariant sdk.Invariant) {
_, exists := suite.invariants[moduleName]
if !exists {
suite.invariants[moduleName] = make(map[string]sdk.Invariant)
}
suite.invariants[moduleName][route] = invariant
}
func (suite *invariantTestSuite) runInvariant(route string, invariant func(k keeper.Keeper) sdk.Invariant) (string, bool) {
ctx := suite.Ctx
registeredInvariant := suite.invariants[types.ModuleName][route]
suite.Require().NotNil(registeredInvariant)
// direct call
dMessage, dBroken := invariant(suite.Keeper)(ctx)
// registered call
rMessage, rBroken := registeredInvariant(ctx)
// all call
aMessage, aBroken := keeper.AllInvariants(suite.Keeper)(ctx)
// require matching values for direct call and registered call
suite.Require().Equal(dMessage, rMessage, "expected registered invariant message to match")
suite.Require().Equal(dBroken, rBroken, "expected registered invariant broken to match")
// require matching values for direct call and all invariants call if broken
suite.Require().Equal(dBroken, aBroken, "expected all invariant broken to match")
if dBroken {
suite.Require().Equal(dMessage, aMessage, "expected all invariant message to match")
}
// return message, broken
return dMessage, dBroken
}
func (suite *invariantTestSuite) TestPoolRecordsInvariant() {
// default state is valid
message, broken := suite.runInvariant("pool-records", keeper.PoolRecordsInvariant)
suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message)
suite.Equal(false, broken)
suite.SetupValidState()
message, broken = suite.runInvariant("pool-records", keeper.PoolRecordsInvariant)
suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message)
suite.Equal(false, broken)
// broken with invalid pool record
suite.Keeper.SetPool_Raw(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
sdk.NewInt(-1e6),
))
message, broken = suite.runInvariant("pool-records", keeper.PoolRecordsInvariant)
suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message)
suite.Equal(true, broken)
}
func (suite *invariantTestSuite) TestShareRecordsInvariant() {
message, broken := suite.runInvariant("share-records", keeper.ShareRecordsInvariant)
suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message)
suite.Equal(false, broken)
suite.SetupValidState()
message, broken = suite.runInvariant("share-records", keeper.ShareRecordsInvariant)
suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message)
suite.Equal(false, broken)
// broken with invalid share record
suite.Keeper.SetDepositorShares_Raw(suite.Ctx, types.NewShareRecord(
sdk.AccAddress("depositor 1"),
types.PoolID("ukava", "usdx"),
sdk.NewInt(-1e6),
))
message, broken = suite.runInvariant("share-records", keeper.ShareRecordsInvariant)
suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message)
suite.Equal(true, broken)
}
func (suite *invariantTestSuite) TestPoolReservesInvariant() {
message, broken := suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant)
suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message)
suite.Equal(false, broken)
suite.SetupValidState()
message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant)
suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message)
suite.Equal(false, broken)
// broken when reserves are greater than module balance
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(2e6)),
sdk.NewCoin("usdx", sdk.NewInt(10e6)),
),
sdk.NewInt(5e6),
))
message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant)
suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message)
suite.Equal(true, broken)
// broken when reserves are less than the module balance
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e5)),
sdk.NewCoin("usdx", sdk.NewInt(5e5)),
),
sdk.NewInt(3e5),
))
message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant)
suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message)
suite.Equal(true, broken)
}
func (suite *invariantTestSuite) TestPoolSharesInvariant() {
message, broken := suite.runInvariant("pool-shares", keeper.PoolSharesInvariant)
suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message)
suite.Equal(false, broken)
suite.SetupValidState()
message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant)
suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message)
suite.Equal(false, broken)
// broken when total shares are greater than depositor shares
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
sdk.NewInt(5e6),
))
message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant)
suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message)
suite.Equal(true, broken)
// broken when total shares are less than the depositor shares
suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
sdk.NewInt(1e5),
))
message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant)
suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message)
suite.Equal(true, broken)
// broken when pool record is missing
suite.Keeper.DeletePool(suite.Ctx, types.PoolID("ukava", "usdx"))
suite.RemoveCoinsFromModule(
sdk.NewCoins(
sdk.NewCoin("ukava", sdk.NewInt(1e6)),
sdk.NewCoin("usdx", sdk.NewInt(5e6)),
),
)
message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant)
suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message)
suite.Equal(true, broken)
}
func TestInvariantTestSuite(t *testing.T) {
suite.Run(t, new(invariantTestSuite))
}

View File

@ -96,7 +96,9 @@ func (AppModule) Name() string {
} }
// RegisterInvariants register module invariants // RegisterInvariants register module invariants
func (AppModule) RegisterInvariants(_ sdk.InvariantRegistry) {} func (am AppModule) RegisterInvariants(ir sdk.InvariantRegistry) {
keeper.RegisterInvariants(ir, am.keeper)
}
// Route module message route name // Route module message route name
func (AppModule) Route() string { func (AppModule) Route() string {

40
x/swap/module_test.go Normal file
View File

@ -0,0 +1,40 @@
package swap_test
import (
"testing"
"github.com/kava-labs/kava/x/swap"
"github.com/kava-labs/kava/x/swap/testutil"
"github.com/cosmos/cosmos-sdk/x/crisis"
"github.com/stretchr/testify/suite"
)
type moduleTestSuite struct {
testutil.Suite
crisisKeeper crisis.Keeper
}
func (suite *moduleTestSuite) SetupTest() {
suite.Suite.SetupTest()
suite.crisisKeeper = suite.App.GetCrisisKeeper()
}
func (suite *moduleTestSuite) TestRegisterInvariants() {
swapRoutes := []string{}
for _, route := range suite.crisisKeeper.Routes() {
if route.ModuleName == swap.ModuleName {
swapRoutes = append(swapRoutes, route.Route)
}
}
suite.Contains(swapRoutes, "pool-records")
suite.Contains(swapRoutes, "share-records")
suite.Contains(swapRoutes, "pool-reserves")
suite.Contains(swapRoutes, "pool-shares")
}
func TestModuleTestSuite(t *testing.T) {
suite.Run(t, new(moduleTestSuite))
}