mirror of
https://github.com/0glabs/0g-chain.git
synced 2025-01-24 22:15:17 +00:00
feat(x/precisebank): Add keeper methods for store (#1912)
- Add store methods to get/set/delete/etc account fractional balances & remainder amount - Add invariants to ensure stored state is correct
This commit is contained in:
parent
d66b7d2705
commit
4ff43eb270
109
x/precisebank/keeper/fractional_balance.go
Normal file
109
x/precisebank/keeper/fractional_balance.go
Normal file
@ -0,0 +1,109 @@
|
||||
package keeper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
"github.com/cosmos/cosmos-sdk/store/prefix"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
)
|
||||
|
||||
// GetFractionalBalance returns the fractional balance for an address.
|
||||
func (k *Keeper) GetFractionalBalance(
|
||||
ctx sdk.Context,
|
||||
address sdk.AccAddress,
|
||||
) (sdkmath.Int, bool) {
|
||||
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.FractionalBalancePrefix)
|
||||
|
||||
bz := store.Get(types.FractionalBalanceKey(address))
|
||||
if bz == nil {
|
||||
return sdkmath.ZeroInt(), false
|
||||
}
|
||||
|
||||
var bal sdkmath.Int
|
||||
if err := bal.Unmarshal(bz); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal fractional balance: %w", err))
|
||||
}
|
||||
|
||||
return bal, true
|
||||
}
|
||||
|
||||
// SetFractionalBalance sets the fractional balance for an address.
|
||||
func (k *Keeper) SetFractionalBalance(
|
||||
ctx sdk.Context,
|
||||
address sdk.AccAddress,
|
||||
amount sdkmath.Int,
|
||||
) {
|
||||
if address.Empty() {
|
||||
panic(errors.New("address cannot be empty"))
|
||||
}
|
||||
|
||||
if amount.IsZero() {
|
||||
k.DeleteFractionalBalance(ctx, address)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the fractional balance is valid before setting it. Use the
|
||||
// NewFractionalAmountFromInt wrapper to use its Validate() method.
|
||||
if err := types.NewFractionalAmountFromInt(amount).Validate(); err != nil {
|
||||
panic(fmt.Errorf("amount is invalid: %w", err))
|
||||
}
|
||||
|
||||
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.FractionalBalancePrefix)
|
||||
|
||||
amountBytes, err := amount.Marshal()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal fractional balance: %w", err))
|
||||
}
|
||||
|
||||
store.Set(types.FractionalBalanceKey(address), amountBytes)
|
||||
}
|
||||
|
||||
// DeleteFractionalBalance deletes the fractional balance for an address.
|
||||
func (k *Keeper) DeleteFractionalBalance(
|
||||
ctx sdk.Context,
|
||||
address sdk.AccAddress,
|
||||
) {
|
||||
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.FractionalBalancePrefix)
|
||||
store.Delete(types.FractionalBalanceKey(address))
|
||||
}
|
||||
|
||||
// IterateFractionalBalances iterates over all fractional balances in the store
|
||||
// and performs a callback function.
|
||||
func (k *Keeper) IterateFractionalBalances(
|
||||
ctx sdk.Context,
|
||||
cb func(address sdk.AccAddress, amount sdkmath.Int) (stop bool),
|
||||
) {
|
||||
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.FractionalBalancePrefix)
|
||||
|
||||
iterator := store.Iterator(nil, nil)
|
||||
defer iterator.Close()
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
address := sdk.AccAddress(iterator.Key())
|
||||
|
||||
var amount sdkmath.Int
|
||||
if err := amount.Unmarshal(iterator.Value()); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal fractional balance: %w", err))
|
||||
}
|
||||
|
||||
if cb(address, amount) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTotalSumFractionalBalances returns the sum of all fractional balances.
|
||||
func (k *Keeper) GetTotalSumFractionalBalances(ctx sdk.Context) sdkmath.Int {
|
||||
sum := sdkmath.ZeroInt()
|
||||
|
||||
k.IterateFractionalBalances(ctx, func(_ sdk.AccAddress, amount sdkmath.Int) bool {
|
||||
sum = sum.Add(amount)
|
||||
return false
|
||||
})
|
||||
|
||||
return sum
|
||||
}
|
189
x/precisebank/keeper/fractional_balance_test.go
Normal file
189
x/precisebank/keeper/fractional_balance_test.go
Normal file
@ -0,0 +1,189 @@
|
||||
package keeper_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
)
|
||||
|
||||
func TestSetGetFractionalBalance(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
addr := sdk.AccAddress([]byte("test-address"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
address sdk.AccAddress
|
||||
amount sdkmath.Int
|
||||
setPanicMsg string
|
||||
}{
|
||||
{
|
||||
"valid - min amount",
|
||||
addr,
|
||||
sdkmath.NewInt(1),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - positive amount",
|
||||
addr,
|
||||
sdkmath.NewInt(100),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - max amount",
|
||||
addr,
|
||||
types.ConversionFactor().SubRaw(1),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - zero amount (deletes)",
|
||||
addr,
|
||||
sdkmath.ZeroInt(),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid - negative amount",
|
||||
addr,
|
||||
sdkmath.NewInt(-1),
|
||||
"amount is invalid: non-positive amount -1",
|
||||
},
|
||||
{
|
||||
"invalid - over max amount",
|
||||
addr,
|
||||
types.ConversionFactor(),
|
||||
"amount is invalid: amount 1000000000000 exceeds max of 999999999999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setPanicMsg != "" {
|
||||
require.PanicsWithError(t, tt.setPanicMsg, func() {
|
||||
k.SetFractionalBalance(ctx, tt.address, tt.amount)
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
k.SetFractionalBalance(ctx, tt.address, tt.amount)
|
||||
})
|
||||
|
||||
// If its zero balance, check it was deleted
|
||||
if tt.amount.IsZero() {
|
||||
_, exists := k.GetFractionalBalance(ctx, tt.address)
|
||||
require.False(t, exists)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
gotAmount, exists := k.GetFractionalBalance(ctx, tt.address)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, tt.amount, gotAmount)
|
||||
|
||||
// Delete balance
|
||||
k.DeleteFractionalBalance(ctx, tt.address)
|
||||
|
||||
_, exists = k.GetFractionalBalance(ctx, tt.address)
|
||||
require.False(t, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFractionalBalance_InvalidAddr(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
require.PanicsWithError(
|
||||
t,
|
||||
"address cannot be empty",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{}, sdkmath.NewInt(100))
|
||||
},
|
||||
"setting balance with empty address should panic",
|
||||
)
|
||||
}
|
||||
|
||||
func TestSetFractionalBalance_ZeroDeletes(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
addr := sdk.AccAddress([]byte("test-address"))
|
||||
|
||||
// Set balance
|
||||
k.SetFractionalBalance(ctx, addr, sdkmath.NewInt(100))
|
||||
|
||||
bal, exists := k.GetFractionalBalance(ctx, addr)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, sdkmath.NewInt(100), bal)
|
||||
|
||||
// Set zero balance
|
||||
k.SetFractionalBalance(ctx, addr, sdkmath.ZeroInt())
|
||||
|
||||
_, exists = k.GetFractionalBalance(ctx, addr)
|
||||
require.False(t, exists)
|
||||
|
||||
// Set zero balance again on non-existent balance
|
||||
require.NotPanics(
|
||||
t,
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, addr, sdkmath.ZeroInt())
|
||||
},
|
||||
"deleting non-existent balance should not panic",
|
||||
)
|
||||
}
|
||||
|
||||
func TestIterateFractionalBalances(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
addrs := []sdk.AccAddress{}
|
||||
|
||||
for i := 1; i < 10; i++ {
|
||||
addr := sdk.AccAddress([]byte{byte(i)})
|
||||
addrs = append(addrs, addr)
|
||||
|
||||
// Set balance same as their address byte
|
||||
k.SetFractionalBalance(ctx, addr, sdkmath.NewInt(int64(i)))
|
||||
}
|
||||
|
||||
seenAddrs := []sdk.AccAddress{}
|
||||
|
||||
k.IterateFractionalBalances(ctx, func(addr sdk.AccAddress, bal sdkmath.Int) bool {
|
||||
seenAddrs = append(seenAddrs, addr)
|
||||
|
||||
// Balance is same as first address byte
|
||||
require.Equal(t, int64(addr.Bytes()[0]), bal.Int64())
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
require.ElementsMatch(t, addrs, seenAddrs, "all addresses should be seen")
|
||||
}
|
||||
|
||||
func TestGetAggregateSumFractionalBalances(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
// Set balances from 1 to 10
|
||||
sum := sdkmath.ZeroInt()
|
||||
for i := 1; i < 10; i++ {
|
||||
addr := sdk.AccAddress([]byte{byte(i)})
|
||||
amt := sdkmath.NewInt(int64(i))
|
||||
|
||||
sum = sum.Add(amt)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
k.SetFractionalBalance(ctx, addr, amt)
|
||||
})
|
||||
}
|
||||
|
||||
gotSum := k.GetTotalSumFractionalBalances(ctx)
|
||||
require.Equal(t, sum, gotSum)
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
package keeper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
)
|
||||
@ -11,11 +14,110 @@ func RegisterInvariants(
|
||||
k Keeper,
|
||||
bk types.BankKeeper,
|
||||
) {
|
||||
ir.RegisterRoute(types.ModuleName, "balance-remainder-total", BalancedFractionalTotalInvariant(k))
|
||||
ir.RegisterRoute(types.ModuleName, "valid-fractional-balances", ValidFractionalAmountsInvariant(k))
|
||||
ir.RegisterRoute(types.ModuleName, "valid-remainder-amount", ValidRemainderAmountInvariant(k))
|
||||
}
|
||||
|
||||
// AllInvariants runs all invariants of the X/precisebank module.
|
||||
func AllInvariants(k Keeper) sdk.Invariant {
|
||||
return func(ctx sdk.Context) (string, bool) {
|
||||
res, stop := BalancedFractionalTotalInvariant(k)(ctx)
|
||||
if stop {
|
||||
return res, stop
|
||||
}
|
||||
|
||||
res, stop = ValidFractionalAmountsInvariant(k)(ctx)
|
||||
if stop {
|
||||
return res, stop
|
||||
}
|
||||
|
||||
res, stop = ValidRemainderAmountInvariant(k)(ctx)
|
||||
if stop {
|
||||
return res, stop
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// ValidFractionalAmountsInvariant checks that all individual fractional
|
||||
// balances are valid.
|
||||
func ValidFractionalAmountsInvariant(k Keeper) sdk.Invariant {
|
||||
return func(ctx sdk.Context) (string, bool) {
|
||||
var (
|
||||
msg string
|
||||
count int
|
||||
)
|
||||
|
||||
k.IterateFractionalBalances(ctx, func(addr sdk.AccAddress, amount sdkmath.Int) bool {
|
||||
if err := types.NewFractionalAmountFromInt(amount).Validate(); err != nil {
|
||||
count++
|
||||
msg += fmt.Sprintf("\t%s has an invalid fractional amount of %s\n", addr, amount)
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
broken := count != 0
|
||||
|
||||
return sdk.FormatInvariant(
|
||||
types.ModuleName, "valid-fractional-balances",
|
||||
fmt.Sprintf("amount of invalid fractional balances found %d\n%s", count, msg),
|
||||
), broken
|
||||
}
|
||||
}
|
||||
|
||||
// ValidRemainderAmountInvariant checks that the remainder amount is valid.
|
||||
func ValidRemainderAmountInvariant(k Keeper) sdk.Invariant {
|
||||
return func(ctx sdk.Context) (string, bool) {
|
||||
var (
|
||||
msg string
|
||||
broken bool
|
||||
)
|
||||
|
||||
remainderAmount := k.GetRemainderAmount(ctx)
|
||||
|
||||
if !remainderAmount.IsZero() {
|
||||
// Only validate if non-zero, as zero is default value
|
||||
if err := types.NewFractionalAmountFromInt(remainderAmount).Validate(); err != nil {
|
||||
broken = true
|
||||
msg = fmt.Sprintf("remainder amount is invalid: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sdk.FormatInvariant(
|
||||
types.ModuleName, "valid-remainder-amount",
|
||||
msg,
|
||||
), broken
|
||||
}
|
||||
}
|
||||
|
||||
// BalancedFractionalTotalInvariant checks that the sum of fractional balances
|
||||
// and the remainder amount is divisible by the conversion factor without any
|
||||
// leftover amount.
|
||||
func BalancedFractionalTotalInvariant(k Keeper) sdk.Invariant {
|
||||
return func(ctx sdk.Context) (string, bool) {
|
||||
fractionalBalSum := k.GetTotalSumFractionalBalances(ctx)
|
||||
remainderAmount := k.GetRemainderAmount(ctx)
|
||||
|
||||
total := fractionalBalSum.Add(remainderAmount)
|
||||
fractionalAmount := total.Mod(types.ConversionFactor())
|
||||
|
||||
broken := false
|
||||
msg := ""
|
||||
|
||||
if !fractionalAmount.IsZero() {
|
||||
broken = true
|
||||
msg = fmt.Sprintf(
|
||||
"(sum(FractionalBalances) + remainder) %% conversionFactor should be 0 but got %v",
|
||||
fractionalAmount,
|
||||
)
|
||||
}
|
||||
|
||||
return sdk.FormatInvariant(
|
||||
types.ModuleName, "balance-remainder-total",
|
||||
msg,
|
||||
), broken
|
||||
}
|
||||
}
|
||||
|
170
x/precisebank/keeper/invariants_test.go
Normal file
170
x/precisebank/keeper/invariants_test.go
Normal file
@ -0,0 +1,170 @@
|
||||
package keeper_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
"github.com/cosmos/cosmos-sdk/store/prefix"
|
||||
storetypes "github.com/cosmos/cosmos-sdk/store/types"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
|
||||
"github.com/kava-labs/kava/x/precisebank/keeper"
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBalancedFractionalTotalInvariant(t *testing.T) {
|
||||
var ctx sdk.Context
|
||||
var k keeper.Keeper
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFn func()
|
||||
wantBroken bool
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
"valid - empty state",
|
||||
func() {},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - balances, 0 remainder",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2))
|
||||
},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - balances, non-zero remainder",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2).SubRaw(1))
|
||||
|
||||
k.SetRemainderAmount(ctx, sdkmath.OneInt())
|
||||
},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid - balances, 0 remainder",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2).SubRaw(1))
|
||||
},
|
||||
true,
|
||||
"precisebank: balance-remainder-total invariant\n(sum(FractionalBalances) + remainder) % conversionFactor should be 0 but got 999999999999\n",
|
||||
},
|
||||
{
|
||||
"invalid - invalid balances, non-zero (insufficient) remainder",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2).SubRaw(2))
|
||||
k.SetRemainderAmount(ctx, sdkmath.OneInt())
|
||||
},
|
||||
true,
|
||||
"precisebank: balance-remainder-total invariant\n(sum(FractionalBalances) + remainder) % conversionFactor should be 0 but got 999999999999\n",
|
||||
},
|
||||
{
|
||||
"invalid - invalid balances, non-zero (excess) remainder",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2).SubRaw(2))
|
||||
k.SetRemainderAmount(ctx, sdkmath.NewInt(5))
|
||||
},
|
||||
true,
|
||||
"precisebank: balance-remainder-total invariant\n(sum(FractionalBalances) + remainder) % conversionFactor should be 0 but got 3\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset each time
|
||||
tk := NewTestKeeper()
|
||||
ctx, k = tk.ctx, tk.keeper
|
||||
|
||||
tt.setupFn()
|
||||
|
||||
invariantFn := keeper.BalancedFractionalTotalInvariant(k)
|
||||
msg, broken := invariantFn(ctx)
|
||||
|
||||
if tt.wantBroken {
|
||||
require.True(t, broken, "invariant should be broken but is not")
|
||||
require.Equal(t, tt.wantMsg, msg)
|
||||
} else {
|
||||
require.False(t, broken, "invariant should not be broken but is")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidFractionalAmountsInvariant(t *testing.T) {
|
||||
var ctx sdk.Context
|
||||
var k keeper.Keeper
|
||||
var storeKey storetypes.StoreKey
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFn func()
|
||||
wantBroken bool
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
"valid - empty state",
|
||||
func() {},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"valid - valid balances",
|
||||
func() {
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{1}, types.ConversionFactor().QuoRaw(2))
|
||||
k.SetFractionalBalance(ctx, sdk.AccAddress{2}, types.ConversionFactor().QuoRaw(2))
|
||||
},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid - exceeds max balance",
|
||||
func() {
|
||||
// Requires manual store manipulation so it is unlikely to have
|
||||
// invalid state in practice. SetFractionalBalance will validate
|
||||
// before setting.
|
||||
addr := sdk.AccAddress{1}
|
||||
amount := types.ConversionFactor()
|
||||
|
||||
store := prefix.NewStore(ctx.KVStore(storeKey), types.FractionalBalancePrefix)
|
||||
|
||||
amountBytes, err := amount.Marshal()
|
||||
require.NoError(t, err)
|
||||
|
||||
store.Set(types.FractionalBalanceKey(addr), amountBytes)
|
||||
},
|
||||
true,
|
||||
"precisebank: valid-fractional-balances invariant\namount of invalid fractional balances found 1\n\tkava1qy0xn7za has an invalid fractional amount of 1000000000000\n\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset each time
|
||||
tk := NewTestKeeper()
|
||||
ctx, k, storeKey = tk.ctx, tk.keeper, tk.storeKey
|
||||
|
||||
tt.setupFn()
|
||||
|
||||
invariantFn := keeper.ValidFractionalAmountsInvariant(k)
|
||||
msg, broken := invariantFn(ctx)
|
||||
|
||||
if tt.wantBroken {
|
||||
require.True(t, broken, "invariant should be broken but is not")
|
||||
require.Equal(t, tt.wantMsg, msg)
|
||||
} else {
|
||||
require.False(t, broken, "invariant should not be broken but is")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
36
x/precisebank/keeper/keeper_test.go
Normal file
36
x/precisebank/keeper/keeper_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package keeper_test
|
||||
|
||||
import (
|
||||
storetypes "github.com/cosmos/cosmos-sdk/store/types"
|
||||
"github.com/cosmos/cosmos-sdk/testutil"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
|
||||
"github.com/kava-labs/kava/app"
|
||||
"github.com/kava-labs/kava/x/precisebank/keeper"
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
)
|
||||
|
||||
// testKeeper defines necessary fields for testing keeper store methods that
|
||||
// don't require a full app setup.
|
||||
type testKeeper struct {
|
||||
ctx sdk.Context
|
||||
keeper keeper.Keeper
|
||||
storeKey *storetypes.KVStoreKey
|
||||
}
|
||||
|
||||
func NewTestKeeper() testKeeper {
|
||||
storeKey := sdk.NewKVStoreKey(types.ModuleName)
|
||||
// Not required by module, but needs to be non-nil for context
|
||||
tKey := sdk.NewTransientStoreKey("transient_test")
|
||||
ctx := testutil.DefaultContext(storeKey, tKey)
|
||||
|
||||
tApp := app.NewTestApp()
|
||||
cdc := tApp.AppCodec()
|
||||
k := keeper.NewKeeper(cdc, storeKey)
|
||||
|
||||
return testKeeper{
|
||||
ctx: ctx,
|
||||
keeper: k,
|
||||
storeKey: storeKey,
|
||||
}
|
||||
}
|
66
x/precisebank/keeper/remainder_amount.go
Normal file
66
x/precisebank/keeper/remainder_amount.go
Normal file
@ -0,0 +1,66 @@
|
||||
package keeper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
)
|
||||
|
||||
// GetRemainderAmount returns the internal remainder amount.
|
||||
func (k *Keeper) GetRemainderAmount(
|
||||
ctx sdk.Context,
|
||||
) sdkmath.Int {
|
||||
store := ctx.KVStore(k.storeKey)
|
||||
|
||||
bz := store.Get(types.RemainderBalanceKey)
|
||||
if bz == nil {
|
||||
return sdkmath.ZeroInt()
|
||||
}
|
||||
|
||||
var bal sdkmath.Int
|
||||
if err := bal.Unmarshal(bz); err != nil {
|
||||
panic(fmt.Errorf("failed to unmarshal remainder amount: %w", err))
|
||||
}
|
||||
|
||||
return bal
|
||||
}
|
||||
|
||||
// SetRemainderAmount sets the internal remainder amount.
|
||||
func (k *Keeper) SetRemainderAmount(
|
||||
ctx sdk.Context,
|
||||
amount sdkmath.Int,
|
||||
) {
|
||||
// Prevent storing zero amounts. In practice, the remainder amount should
|
||||
// only be non-zero during transactions as mint and burns should net zero
|
||||
// due to only being used for EVM transfers.
|
||||
if amount.IsZero() {
|
||||
k.DeleteRemainderAmount(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the remainder is valid before setting it. Follows the same
|
||||
// validation as FractionalBalance with the same value range.
|
||||
if err := types.NewFractionalAmountFromInt(amount).Validate(); err != nil {
|
||||
panic(fmt.Errorf("remainder amount is invalid: %w", err))
|
||||
}
|
||||
|
||||
store := ctx.KVStore(k.storeKey)
|
||||
|
||||
amountBytes, err := amount.Marshal()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal remainder amount: %w", err))
|
||||
}
|
||||
|
||||
store.Set(types.RemainderBalanceKey, amountBytes)
|
||||
}
|
||||
|
||||
// DeleteRemainderAmount deletes the internal remainder amount.
|
||||
func (k *Keeper) DeleteRemainderAmount(
|
||||
ctx sdk.Context,
|
||||
) {
|
||||
store := ctx.KVStore(k.storeKey)
|
||||
store.Delete(types.RemainderBalanceKey)
|
||||
}
|
71
x/precisebank/keeper/remainder_amount_test.go
Normal file
71
x/precisebank/keeper/remainder_amount_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package keeper_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetSetRemainderAmount(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k, storeKey := tk.ctx, tk.keeper, tk.storeKey
|
||||
|
||||
// Set amount
|
||||
k.SetRemainderAmount(ctx, sdkmath.NewInt(100))
|
||||
|
||||
amt := k.GetRemainderAmount(ctx)
|
||||
require.Equal(t, sdkmath.NewInt(100), amt)
|
||||
|
||||
// Set zero balance
|
||||
k.SetRemainderAmount(ctx, sdkmath.ZeroInt())
|
||||
|
||||
amt = k.GetRemainderAmount(ctx)
|
||||
require.Equal(t, sdkmath.ZeroInt(), amt)
|
||||
|
||||
// Get directly from store to make sure it was actually deleted
|
||||
store := ctx.KVStore(storeKey)
|
||||
bz := store.Get(types.RemainderBalanceKey)
|
||||
require.Nil(t, bz)
|
||||
}
|
||||
|
||||
func TestInvalidRemainderAmount(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k := tk.ctx, tk.keeper
|
||||
|
||||
// Set negative amount
|
||||
require.PanicsWithError(t, "remainder amount is invalid: non-positive amount -1", func() {
|
||||
k.SetRemainderAmount(ctx, sdkmath.NewInt(-1))
|
||||
})
|
||||
|
||||
// Set amount over max
|
||||
require.PanicsWithError(t, "remainder amount is invalid: amount 1000000000000 exceeds max of 999999999999", func() {
|
||||
k.SetRemainderAmount(ctx, types.ConversionFactor())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteRemainderAmount(t *testing.T) {
|
||||
tk := NewTestKeeper()
|
||||
ctx, k, storeKey := tk.ctx, tk.keeper, tk.storeKey
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
k.DeleteRemainderAmount(ctx)
|
||||
})
|
||||
|
||||
// Set amount
|
||||
k.SetRemainderAmount(ctx, sdkmath.NewInt(100))
|
||||
|
||||
amt := k.GetRemainderAmount(ctx)
|
||||
require.Equal(t, sdkmath.NewInt(100), amt)
|
||||
|
||||
// Delete amount
|
||||
k.DeleteRemainderAmount(ctx)
|
||||
|
||||
amt = k.GetRemainderAmount(ctx)
|
||||
require.Equal(t, sdkmath.ZeroInt(), amt)
|
||||
|
||||
store := ctx.KVStore(storeKey)
|
||||
bz := store.Get(types.RemainderBalanceKey)
|
||||
require.Nil(t, bz)
|
||||
}
|
41
x/precisebank/types/fractional_amount.go
Normal file
41
x/precisebank/types/fractional_amount.go
Normal file
@ -0,0 +1,41 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
)
|
||||
|
||||
// FractionalAmount represents a fractional amount between the valid range of 1
|
||||
// and maxFractionalAmount. This wraps an sdkmath.Int to provide additional
|
||||
// validation methods so it can be re-used in multiple places.
|
||||
type FractionalAmount struct {
|
||||
sdkmath.Int
|
||||
}
|
||||
|
||||
// NewFractionalAmountFromInt creates a new FractionalAmount from an sdkmath.Int.
|
||||
func NewFractionalAmountFromInt(i sdkmath.Int) FractionalAmount {
|
||||
return FractionalAmount{i}
|
||||
}
|
||||
|
||||
// NewFractionalAmount creates a new FractionalAmount from an int64.
|
||||
func NewFractionalAmount(i int64) FractionalAmount {
|
||||
return FractionalAmount{sdkmath.NewInt(i)}
|
||||
}
|
||||
|
||||
// Validate checks if the FractionalAmount is valid.
|
||||
func (f FractionalAmount) Validate() error {
|
||||
if f.IsNil() {
|
||||
return fmt.Errorf("nil amount")
|
||||
}
|
||||
|
||||
if !f.IsPositive() {
|
||||
return fmt.Errorf("non-positive amount %v", f)
|
||||
}
|
||||
|
||||
if f.GT(maxFractionalAmount) {
|
||||
return fmt.Errorf("amount %v exceeds max of %v", f, maxFractionalAmount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,8 +1,6 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
|
||||
sdkmath "cosmossdk.io/math"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
)
|
||||
@ -40,17 +38,6 @@ func (fb FractionalBalance) Validate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if fb.Amount.IsNil() {
|
||||
return fmt.Errorf("nil amount")
|
||||
}
|
||||
|
||||
if !fb.Amount.IsPositive() {
|
||||
return fmt.Errorf("non-positive amount %v", fb.Amount)
|
||||
}
|
||||
|
||||
if fb.Amount.GT(maxFractionalAmount) {
|
||||
return fmt.Errorf("amount %v exceeds max of %v", fb.Amount, maxFractionalAmount)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Validate the amount with the FractionalAmount wrapper
|
||||
return NewFractionalAmountFromInt(fb.Amount).Validate()
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
package types
|
||||
|
||||
import sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
|
||||
const (
|
||||
// ModuleName name that will be used throughout the module
|
||||
ModuleName = "precisebank"
|
||||
@ -9,3 +11,18 @@ const (
|
||||
// RouterKey Top level router key
|
||||
RouterKey = ModuleName
|
||||
)
|
||||
|
||||
// key prefixes for store
|
||||
var (
|
||||
FractionalBalancePrefix = []byte{0x01} // address -> fractional balance
|
||||
)
|
||||
|
||||
// Keys for store that are not prefixed
|
||||
var (
|
||||
RemainderBalanceKey = []byte{0x02} // fractional balance remainder
|
||||
)
|
||||
|
||||
// FractionalBalanceKey returns a key from an address
|
||||
func FractionalBalanceKey(address sdk.AccAddress) []byte {
|
||||
return address.Bytes()
|
||||
}
|
||||
|
17
x/precisebank/types/keys_test.go
Normal file
17
x/precisebank/types/keys_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
"github.com/kava-labs/kava/x/precisebank/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFractionalBalanceKey(t *testing.T) {
|
||||
addr := sdk.AccAddress([]byte("test-address"))
|
||||
|
||||
key := types.FractionalBalanceKey(addr)
|
||||
require.Equal(t, addr.Bytes(), key)
|
||||
require.Equal(t, addr, sdk.AccAddress(key), "key should be able to be converted back to address")
|
||||
}
|
Loading…
Reference in New Issue
Block a user