From b86cfc9f14bea5d46c2d8b0b07c1269094a21699 Mon Sep 17 00:00:00 2001 From: Nick DeLuca Date: Thu, 5 Aug 2021 20:43:55 -0500 Subject: [PATCH] Add Invariants for Swap Module (#979) * add swap module invariants * typo * update alias file for invariants * typo in test name * fix typo - method iterates share record, not pools --- x/swap/alias.go | 7 + x/swap/keeper/invariant_test.go | 19 --- x/swap/keeper/invariants.go | 138 ++++++++++++++++++ x/swap/keeper/invariants_test.go | 235 +++++++++++++++++++++++++++++++ x/swap/module.go | 4 +- x/swap/module_test.go | 40 ++++++ 6 files changed, 423 insertions(+), 20 deletions(-) delete mode 100644 x/swap/keeper/invariant_test.go create mode 100644 x/swap/keeper/invariants.go create mode 100644 x/swap/keeper/invariants_test.go create mode 100644 x/swap/module_test.go diff --git a/x/swap/alias.go b/x/swap/alias.go index 76b95c84..e0db2d18 100644 --- a/x/swap/alias.go +++ b/x/swap/alias.go @@ -24,6 +24,7 @@ const ( EventTypeSwapWithdraw = types.EventTypeSwapWithdraw ModuleAccountName = types.ModuleAccountName ModuleName = types.ModuleName + PoolIDSep = types.PoolIDSep QuerierRoute = types.QuerierRoute QueryGetDeposits = types.QueryGetDeposits QueryGetParams = types.QueryGetParams @@ -35,8 +36,14 @@ const ( var ( // function aliases + AllInvariants = keeper.AllInvariants NewKeeper = keeper.NewKeeper NewQuerier = keeper.NewQuerier + PoolRecordsInvariant = keeper.PoolRecordsInvariant + PoolReservesInvariant = keeper.PoolReservesInvariant + PoolSharesInvariant = keeper.PoolSharesInvariant + RegisterInvariants = keeper.RegisterInvariants + ShareRecordsInvariant = keeper.ShareRecordsInvariant DefaultGenesisState = types.DefaultGenesisState DefaultParams = types.DefaultParams DepositorPoolSharesKey = types.DepositorPoolSharesKey diff --git a/x/swap/keeper/invariant_test.go b/x/swap/keeper/invariant_test.go deleted file mode 100644 index ab39f78e..00000000 --- a/x/swap/keeper/invariant_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package keeper_test - -import ( - "testing" - - //"github.com/kava-labs/kava/x/swap" - "github.com/kava-labs/kava/x/swap/testutil" - //"github.com/kava-labs/kava/x/swap/types" - "github.com/stretchr/testify/suite" - //sdk "github.com/cosmos/cosmos-sdk/types" -) - -type invariantTestSuite struct { - testutil.Suite -} - -func TestGenesisTestSuite(t *testing.T) { - suite.Run(t, new(invariantTestSuite)) -} diff --git a/x/swap/keeper/invariants.go b/x/swap/keeper/invariants.go new file mode 100644 index 00000000..34e2e28d --- /dev/null +++ b/x/swap/keeper/invariants.go @@ -0,0 +1,138 @@ +package keeper + +import ( + "github.com/kava-labs/kava/x/swap/types" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// RegisterInvariants registers the swap module invariants +func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) { + ir.RegisterRoute(types.ModuleName, "pool-records", PoolRecordsInvariant(k)) + ir.RegisterRoute(types.ModuleName, "share-records", ShareRecordsInvariant(k)) + ir.RegisterRoute(types.ModuleName, "pool-reserves", PoolReservesInvariant(k)) + ir.RegisterRoute(types.ModuleName, "pool-shares", PoolSharesInvariant(k)) +} + +// AllInvariants runs all invariants of the swap module +func AllInvariants(k Keeper) sdk.Invariant { + return func(ctx sdk.Context) (string, bool) { + if res, stop := PoolRecordsInvariant(k)(ctx); stop { + return res, stop + } + + if res, stop := ShareRecordsInvariant(k)(ctx); stop { + return res, stop + } + + if res, stop := PoolReservesInvariant(k)(ctx); stop { + return res, stop + } + + res, stop := PoolSharesInvariant(k)(ctx) + return res, stop + } +} + +// PoolRecordsInvariant iterates all pool records and asserts that they are valid +func PoolRecordsInvariant(k Keeper) sdk.Invariant { + broken := false + message := sdk.FormatInvariant(types.ModuleName, "validate pool records broken", "pool record invalid") + + return func(ctx sdk.Context) (string, bool) { + k.IteratePools(ctx, func(record types.PoolRecord) bool { + if err := record.Validate(); err != nil { + broken = true + return true + } + return false + }) + + return message, broken + } +} + +// ShareRecordsInvariant iterates all share records and asserts that they are valid +func ShareRecordsInvariant(k Keeper) sdk.Invariant { + broken := false + message := sdk.FormatInvariant(types.ModuleName, "validate share records broken", "share record invalid") + + return func(ctx sdk.Context) (string, bool) { + k.IterateDepositorShares(ctx, func(record types.ShareRecord) bool { + if err := record.Validate(); err != nil { + broken = true + return true + } + return false + }) + + return message, broken + } +} + +// PoolReservesInvariant iterates all pools and ensures the total reserves matches the module account coins +func PoolReservesInvariant(k Keeper) sdk.Invariant { + message := sdk.FormatInvariant(types.ModuleName, "pool reserves broken", "pool reserves do not match module account") + + return func(ctx sdk.Context) (string, bool) { + mAcc := k.supplyKeeper.GetModuleAccount(ctx, types.ModuleName) + + reserves := sdk.Coins{} + k.IteratePools(ctx, func(record types.PoolRecord) bool { + for _, coin := range record.Reserves() { + reserves = reserves.Add(coin) + } + return false + }) + + broken := !reserves.IsEqual(mAcc.GetCoins()) + return message, broken + } +} + +type poolShares struct { + totalShares sdk.Int + totalSharesOwned sdk.Int +} + +// PoolSharesInvariant iterates all pools and shares and ensures the total pool shares match the sum of depositor shares +func PoolSharesInvariant(k Keeper) sdk.Invariant { + broken := false + message := sdk.FormatInvariant(types.ModuleName, "pool shares broken", "pool shares do not match depositor shares") + + return func(ctx sdk.Context) (string, bool) { + totalShares := make(map[string]poolShares) + + k.IteratePools(ctx, func(pr types.PoolRecord) bool { + totalShares[pr.PoolID] = poolShares{ + totalShares: pr.TotalShares, + totalSharesOwned: sdk.ZeroInt(), + } + + return false + }) + + k.IterateDepositorShares(ctx, func(sr types.ShareRecord) bool { + if shares, found := totalShares[sr.PoolID]; found { + shares.totalSharesOwned = shares.totalSharesOwned.Add(sr.SharesOwned) + totalShares[sr.PoolID] = shares + } else { + totalShares[sr.PoolID] = poolShares{ + totalShares: sdk.ZeroInt(), + totalSharesOwned: sr.SharesOwned, + } + } + + return false + }) + + for _, ps := range totalShares { + if !ps.totalShares.Equal(ps.totalSharesOwned) { + broken = true + break + } + } + + return message, broken + } +} diff --git a/x/swap/keeper/invariants_test.go b/x/swap/keeper/invariants_test.go new file mode 100644 index 00000000..cb93ae66 --- /dev/null +++ b/x/swap/keeper/invariants_test.go @@ -0,0 +1,235 @@ +package keeper_test + +import ( + "testing" + + "github.com/kava-labs/kava/x/swap/keeper" + "github.com/kava-labs/kava/x/swap/testutil" + "github.com/kava-labs/kava/x/swap/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/suite" +) + +type invariantTestSuite struct { + testutil.Suite + invariants map[string]map[string]sdk.Invariant +} + +func (suite *invariantTestSuite) SetupTest() { + suite.Suite.SetupTest() + suite.invariants = make(map[string]map[string]sdk.Invariant) + keeper.RegisterInvariants(suite, suite.Keeper) +} + +func (suite *invariantTestSuite) SetupValidState() { + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + sdk.NewInt(3e6), + )) + suite.AddCoinsToModule( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + ) + suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord( + sdk.AccAddress("depositor 1"), + types.PoolID("ukava", "usdx"), + sdk.NewInt(2e6), + )) + suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord( + sdk.AccAddress("depositor 2"), + types.PoolID("ukava", "usdx"), + sdk.NewInt(1e6), + )) + + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("hard", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(2e6)), + ), + sdk.NewInt(1e6), + )) + suite.AddCoinsToModule( + sdk.NewCoins( + sdk.NewCoin("hard", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(2e6)), + ), + ) + suite.Keeper.SetDepositorShares(suite.Ctx, types.NewShareRecord( + sdk.AccAddress("depositor 1"), + types.PoolID("hard", "usdx"), + sdk.NewInt(1e6), + )) +} + +func (suite *invariantTestSuite) RegisterRoute(moduleName string, route string, invariant sdk.Invariant) { + _, exists := suite.invariants[moduleName] + + if !exists { + suite.invariants[moduleName] = make(map[string]sdk.Invariant) + } + + suite.invariants[moduleName][route] = invariant +} + +func (suite *invariantTestSuite) runInvariant(route string, invariant func(k keeper.Keeper) sdk.Invariant) (string, bool) { + ctx := suite.Ctx + registeredInvariant := suite.invariants[types.ModuleName][route] + suite.Require().NotNil(registeredInvariant) + + // direct call + dMessage, dBroken := invariant(suite.Keeper)(ctx) + // registered call + rMessage, rBroken := registeredInvariant(ctx) + // all call + aMessage, aBroken := keeper.AllInvariants(suite.Keeper)(ctx) + + // require matching values for direct call and registered call + suite.Require().Equal(dMessage, rMessage, "expected registered invariant message to match") + suite.Require().Equal(dBroken, rBroken, "expected registered invariant broken to match") + // require matching values for direct call and all invariants call if broken + suite.Require().Equal(dBroken, aBroken, "expected all invariant broken to match") + if dBroken { + suite.Require().Equal(dMessage, aMessage, "expected all invariant message to match") + } + + // return message, broken + return dMessage, dBroken +} + +func (suite *invariantTestSuite) TestPoolRecordsInvariant() { + + // default state is valid + message, broken := suite.runInvariant("pool-records", keeper.PoolRecordsInvariant) + suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message) + suite.Equal(false, broken) + + suite.SetupValidState() + message, broken = suite.runInvariant("pool-records", keeper.PoolRecordsInvariant) + suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message) + suite.Equal(false, broken) + + // broken with invalid pool record + suite.Keeper.SetPool_Raw(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + sdk.NewInt(-1e6), + )) + message, broken = suite.runInvariant("pool-records", keeper.PoolRecordsInvariant) + suite.Equal("swap: validate pool records broken invariant\npool record invalid\n", message) + suite.Equal(true, broken) +} + +func (suite *invariantTestSuite) TestShareRecordsInvariant() { + message, broken := suite.runInvariant("share-records", keeper.ShareRecordsInvariant) + suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message) + suite.Equal(false, broken) + + suite.SetupValidState() + message, broken = suite.runInvariant("share-records", keeper.ShareRecordsInvariant) + suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message) + suite.Equal(false, broken) + + // broken with invalid share record + suite.Keeper.SetDepositorShares_Raw(suite.Ctx, types.NewShareRecord( + sdk.AccAddress("depositor 1"), + types.PoolID("ukava", "usdx"), + sdk.NewInt(-1e6), + )) + message, broken = suite.runInvariant("share-records", keeper.ShareRecordsInvariant) + suite.Equal("swap: validate share records broken invariant\nshare record invalid\n", message) + suite.Equal(true, broken) +} + +func (suite *invariantTestSuite) TestPoolReservesInvariant() { + message, broken := suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant) + suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message) + suite.Equal(false, broken) + + suite.SetupValidState() + message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant) + suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message) + suite.Equal(false, broken) + + // broken when reserves are greater than module balance + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(2e6)), + sdk.NewCoin("usdx", sdk.NewInt(10e6)), + ), + sdk.NewInt(5e6), + )) + message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant) + suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message) + suite.Equal(true, broken) + + // broken when reserves are less than the module balance + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e5)), + sdk.NewCoin("usdx", sdk.NewInt(5e5)), + ), + sdk.NewInt(3e5), + )) + message, broken = suite.runInvariant("pool-reserves", keeper.PoolReservesInvariant) + suite.Equal("swap: pool reserves broken invariant\npool reserves do not match module account\n", message) + suite.Equal(true, broken) +} + +func (suite *invariantTestSuite) TestPoolSharesInvariant() { + message, broken := suite.runInvariant("pool-shares", keeper.PoolSharesInvariant) + suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message) + suite.Equal(false, broken) + + suite.SetupValidState() + message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant) + suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message) + suite.Equal(false, broken) + + // broken when total shares are greater than depositor shares + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + sdk.NewInt(5e6), + )) + message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant) + suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message) + suite.Equal(true, broken) + + // broken when total shares are less than the depositor shares + suite.Keeper.SetPool(suite.Ctx, types.NewPoolRecord( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + sdk.NewInt(1e5), + )) + message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant) + suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message) + suite.Equal(true, broken) + + // broken when pool record is missing + suite.Keeper.DeletePool(suite.Ctx, types.PoolID("ukava", "usdx")) + suite.RemoveCoinsFromModule( + sdk.NewCoins( + sdk.NewCoin("ukava", sdk.NewInt(1e6)), + sdk.NewCoin("usdx", sdk.NewInt(5e6)), + ), + ) + message, broken = suite.runInvariant("pool-shares", keeper.PoolSharesInvariant) + suite.Equal("swap: pool shares broken invariant\npool shares do not match depositor shares\n", message) + suite.Equal(true, broken) +} + +func TestInvariantTestSuite(t *testing.T) { + suite.Run(t, new(invariantTestSuite)) +} diff --git a/x/swap/module.go b/x/swap/module.go index 78169628..417d6c26 100644 --- a/x/swap/module.go +++ b/x/swap/module.go @@ -96,7 +96,9 @@ func (AppModule) Name() string { } // RegisterInvariants register module invariants -func (AppModule) RegisterInvariants(_ sdk.InvariantRegistry) {} +func (am AppModule) RegisterInvariants(ir sdk.InvariantRegistry) { + keeper.RegisterInvariants(ir, am.keeper) +} // Route module message route name func (AppModule) Route() string { diff --git a/x/swap/module_test.go b/x/swap/module_test.go new file mode 100644 index 00000000..d6e4e5e6 --- /dev/null +++ b/x/swap/module_test.go @@ -0,0 +1,40 @@ +package swap_test + +import ( + "testing" + + "github.com/kava-labs/kava/x/swap" + "github.com/kava-labs/kava/x/swap/testutil" + + "github.com/cosmos/cosmos-sdk/x/crisis" + "github.com/stretchr/testify/suite" +) + +type moduleTestSuite struct { + testutil.Suite + crisisKeeper crisis.Keeper +} + +func (suite *moduleTestSuite) SetupTest() { + suite.Suite.SetupTest() + suite.crisisKeeper = suite.App.GetCrisisKeeper() +} + +func (suite *moduleTestSuite) TestRegisterInvariants() { + swapRoutes := []string{} + + for _, route := range suite.crisisKeeper.Routes() { + if route.ModuleName == swap.ModuleName { + swapRoutes = append(swapRoutes, route.Route) + } + } + + suite.Contains(swapRoutes, "pool-records") + suite.Contains(swapRoutes, "share-records") + suite.Contains(swapRoutes, "pool-reserves") + suite.Contains(swapRoutes, "pool-shares") +} + +func TestModuleTestSuite(t *testing.T) { + suite.Run(t, new(moduleTestSuite)) +}