From b42c82d59c29145217705e991006e6a729b8e15c Mon Sep 17 00:00:00 2001
From: Solovyov1796 <xinyu@0g.ai>
Date: Wed, 12 Mar 2025 21:09:37 +0800
Subject: [PATCH] add lock and support nonce rewind

---
 app/priority_nonce.go | 103 +++++++++++++++++++++++++++++-------------
 1 file changed, 72 insertions(+), 31 deletions(-)

diff --git a/app/priority_nonce.go b/app/priority_nonce.go
index e2874e97..6a9be970 100644
--- a/app/priority_nonce.go
+++ b/app/priority_nonce.go
@@ -26,6 +26,7 @@ var (
 	errMempoolTxGasPriceTooLow = errors.New("gas price is too low")
 	errMempoolTooManyTxs       = errors.New("tx sender has too many txs in mempool")
 	errMempoolIsFull           = errors.New("mempool is full")
+	errTxInMempool             = errors.New("tx already in mempool")
 )
 
 // PriorityNonceMempool is a mempool implementation that stores txs
@@ -36,6 +37,7 @@ var (
 // priority to other sender txs and must be partially ordered by both sender-nonce
 // and priority.
 type PriorityNonceMempool struct {
+	mtx            sync.Mutex
 	priorityIndex  *skiplist.SkipList
 	priorityCounts map[int64]int
 	senderIndices  map[string]*skiplist.SkipList
@@ -48,7 +50,7 @@ type PriorityNonceMempool struct {
 	counterBySender map[string]int
 	txRecord        map[txMeta]struct{}
 
-	txReplacedCallback func(ctx context.Context, oldTx, newTx sdk.Tx)
+	txReplacedCallback func(ctx context.Context, oldTx, newTx *TxInfo)
 }
 
 type PriorityNonceIterator struct {
@@ -135,7 +137,7 @@ func PriorityNonceWithMaxTx(maxTx int) PriorityNonceMempoolOption {
 	}
 }
 
-func PriorityNonceWithTxReplacedCallback(cb func(ctx context.Context, oldTx, newTx sdk.Tx)) PriorityNonceMempoolOption {
+func PriorityNonceWithTxReplacedCallback(cb func(ctx context.Context, oldTx, newTx *TxInfo)) PriorityNonceMempoolOption {
 	return func(mp *PriorityNonceMempool) {
 		mp.txReplacedCallback = cb
 	}
@@ -188,6 +190,9 @@ func (mp *PriorityNonceMempool) NextSenderTx(sender string) sdk.Tx {
 // Inserting a duplicate tx with a different priority overwrites the existing tx,
 // changing the total order of the mempool.
 func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
+	mp.mtx.Lock()
+	defer mp.mtx.Unlock()
+
 	// if mp.maxTx > 0 && mp.CountTx() >= mp.maxTx {
 	// 	return mempool.ErrMempoolTxMaxCapacity
 	// } else
@@ -197,27 +202,28 @@ func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
 
 	sdkContext := sdk.UnwrapSDKContext(ctx)
 	priority := sdkContext.Priority()
+
 	txInfo, err := extractTxInfo(tx)
 	if err != nil {
 		return err
 	}
 
-	if !mp.canInsert(txInfo.sender) {
-		return errors.Wrapf(errMempoolTooManyTxs, "sender %s has too many txs in mempool", txInfo.sender)
+	if !mp.canInsert(txInfo.Sender) {
+		return errors.Wrapf(errMempoolTooManyTxs, "[%d@%s]sender has too many txs in mempool", txInfo.Nonce, txInfo.Sender)
 	}
 
 	// init sender index if not exists
-	senderIndex, ok := mp.senderIndices[txInfo.sender]
+	senderIndex, ok := mp.senderIndices[txInfo.Sender]
 	if !ok {
 		senderIndex = skiplist.New(skiplist.LessThanFunc(func(a, b any) int {
 			return skiplist.Uint64.Compare(b.(txMeta).nonce, a.(txMeta).nonce)
 		}))
 
 		// initialize sender index if not found
-		mp.senderIndices[txInfo.sender] = senderIndex
+		mp.senderIndices[txInfo.Sender] = senderIndex
 	}
 
-	newKey := txMeta{nonce: txInfo.nonce, priority: priority, sender: txInfo.sender}
+	newKey := txMeta{nonce: txInfo.Nonce, priority: priority, sender: txInfo.Sender}
 
 	// Since mp.priorityIndex is scored by priority, then sender, then nonce, a
 	// changed priority will create a new key, so we must remove the old key and
@@ -227,17 +233,20 @@ func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
 	// This O(log n) remove operation is rare and only happens when a tx's priority
 	// changes.
 
-	sk := txMeta{nonce: txInfo.nonce, sender: txInfo.sender}
+	sk := txMeta{nonce: txInfo.Nonce, sender: txInfo.Sender}
 	if oldScore, txExists := mp.scores[sk]; txExists {
-		oldTx := senderIndex.Get(newKey).Value.(sdk.Tx)
-		return mp.doTxReplace(ctx, newKey, oldScore, oldTx, tx)
+		if oldScore.priority < priority {
+			oldTx := senderIndex.Get(newKey).Value.(sdk.Tx)
+			return mp.doTxReplace(ctx, newKey, oldScore, oldTx, tx)
+		}
+		return errors.Wrapf(errTxInMempool, "[%d@%s] tx already in mempool", txInfo.Nonce, txInfo.Sender)
 	} else {
-		mempoolSize := mp.CountTx()
+		mempoolSize := mp.priorityIndex.Len()
 		if mempoolSize >= mp.maxTx {
-			lowestPriority := mp.GetLowestPriority()
+			lowestPriority := mp.getLowestPriority()
 			// find one to replace
 			if lowestPriority > 0 && priority <= lowestPriority {
-				return errors.Wrapf(errMempoolTxGasPriceTooLow, "tx with priority %d is too low, current lowest priority is %d", priority, lowestPriority)
+				return errors.Wrapf(errMempoolTxGasPriceTooLow, "[%d@%s]tx with priority %d is too low, current lowest priority is %d", newKey.nonce, newKey.sender, priority, lowestPriority)
 			}
 
 			var maxIndexSize int
@@ -245,7 +254,7 @@ func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
 			var selectedElement *skiplist.Element
 			for sender, index := range mp.senderIndices {
 				indexSize := index.Len()
-				if sender == txInfo.sender {
+				if sender == txInfo.Sender {
 					continue
 				}
 
@@ -275,12 +284,16 @@ func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
 				mp.doInsert(newKey, tx, true)
 
 				if mp.txReplacedCallback != nil && replacedTx != nil {
-					mp.txReplacedCallback(ctx, replacedTx, tx)
+					sdkContext.Logger().Debug("txn replaced caused by full of mempool", "old", fmt.Sprintf("%d@%s", key.nonce, key.sender), "new", fmt.Sprintf("%d@%s", newKey.nonce, newKey.sender), "mempoolSize", mempoolSize)
+					mp.txReplacedCallback(ctx,
+						&TxInfo{Sender: key.sender, Nonce: key.nonce, Tx: replacedTx},
+						&TxInfo{Sender: newKey.sender, Nonce: newKey.nonce, Tx: tx},
+					)
 				}
 			} else {
 				// not found any index more than 1 except sender's index
 				// We do not replace the sender's only tx in the mempool
-				return errMempoolIsFull
+				return errors.Wrapf(errMempoolIsFull, "%d@%s with priority%d", newKey.nonce, newKey.sender, newKey.priority)
 			}
 		} else {
 			mp.doInsert(newKey, tx, true)
@@ -315,13 +328,13 @@ func (mp *PriorityNonceMempool) doRemove(oldKey txMeta, decrCnt bool) (sdk.Tx, e
 	scoreKey := txMeta{nonce: oldKey.nonce, sender: oldKey.sender}
 	score, ok := mp.scores[scoreKey]
 	if !ok {
-		return nil, mempool.ErrTxNotFound
+		return nil, errors.Wrapf(mempool.ErrTxNotFound, "%d@%s not found", oldKey.nonce, oldKey.sender)
 	}
 	tk := txMeta{nonce: oldKey.nonce, priority: score.priority, sender: oldKey.sender, weight: score.weight}
 
 	senderTxs, ok := mp.senderIndices[oldKey.sender]
 	if !ok {
-		return nil, fmt.Errorf("sender %s not found", oldKey.sender)
+		return nil, fmt.Errorf("%d@%s not found", oldKey.nonce, oldKey.sender)
 	}
 
 	mp.priorityIndex.Remove(tk)
@@ -363,7 +376,12 @@ func (mp *PriorityNonceMempool) doTxReplace(ctx context.Context, newMate, oldMat
 	mp.doInsert(newMate, newTx, false)
 
 	if mp.txReplacedCallback != nil && replacedTx != nil {
-		mp.txReplacedCallback(ctx, replacedTx, newTx)
+		sdkContext := sdk.UnwrapSDKContext(ctx)
+		sdkContext.Logger().Debug("txn update", "txn", fmt.Sprintf("%d@%s", newMate.nonce, newMate.sender), "oldPriority", oldMate.priority, "newPriority", newMate.priority)
+		mp.txReplacedCallback(ctx,
+			&TxInfo{Sender: newMate.sender, Nonce: newMate.nonce, Tx: replacedTx},
+			&TxInfo{Sender: newMate.sender, Nonce: newMate.nonce, Tx: newTx},
+		)
 	}
 
 	return nil
@@ -445,7 +463,24 @@ func (i *PriorityNonceIterator) Tx() sdk.Tx {
 //
 // NOTE: It is not safe to use this iterator while removing transactions from
 // the underlying mempool.
-func (mp *PriorityNonceMempool) Select(_ context.Context, _ [][]byte) mempool.Iterator {
+func (mp *PriorityNonceMempool) Select(ctx context.Context, txs [][]byte) mempool.Iterator {
+	mp.mtx.Lock()
+	defer mp.mtx.Unlock()
+
+	return mp.doSelect(ctx, txs)
+}
+
+func (mp *PriorityNonceMempool) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) {
+	mp.mtx.Lock()
+	defer mp.mtx.Unlock()
+
+	iter := mp.doSelect(ctx, txs)
+	for iter != nil && callback(iter.Tx()) {
+		iter = iter.Next()
+	}
+}
+
+func (mp *PriorityNonceMempool) doSelect(_ context.Context, _ [][]byte) mempool.Iterator {
 	if mp.priorityIndex.Len() == 0 {
 		return nil
 	}
@@ -514,29 +549,34 @@ func senderWeight(senderCursor *skiplist.Element) int64 {
 
 // CountTx returns the number of transactions in the mempool.
 func (mp *PriorityNonceMempool) CountTx() int {
+	mp.mtx.Lock()
+	defer mp.mtx.Unlock()
 	return mp.priorityIndex.Len()
 }
 
 // Remove removes a transaction from the mempool in O(log n) time, returning an
 // error if unsuccessful.
 func (mp *PriorityNonceMempool) Remove(tx sdk.Tx) error {
+	mp.mtx.Lock()
+	defer mp.mtx.Unlock()
+
 	txInfo, err := extractTxInfo(tx)
 	if err != nil {
 		return err
 	}
 
-	mp.decrSenderTxCnt(txInfo.sender, txInfo.nonce)
+	mp.decrSenderTxCnt(txInfo.Sender, txInfo.Nonce)
 
-	scoreKey := txMeta{nonce: txInfo.nonce, sender: txInfo.sender}
+	scoreKey := txMeta{nonce: txInfo.Nonce, sender: txInfo.Sender}
 	score, ok := mp.scores[scoreKey]
 	if !ok {
 		return mempool.ErrTxNotFound
 	}
-	tk := txMeta{nonce: txInfo.nonce, priority: score.priority, sender: txInfo.sender, weight: score.weight}
+	tk := txMeta{nonce: txInfo.Nonce, priority: score.priority, sender: txInfo.Sender, weight: score.weight}
 
-	senderTxs, ok := mp.senderIndices[txInfo.sender]
+	senderTxs, ok := mp.senderIndices[txInfo.Sender]
 	if !ok {
-		return fmt.Errorf("sender %s not found", txInfo.sender)
+		return fmt.Errorf("sender %s not found", txInfo.Sender)
 	}
 
 	mp.priorityIndex.Remove(tk)
@@ -547,7 +587,7 @@ func (mp *PriorityNonceMempool) Remove(tx sdk.Tx) error {
 	return nil
 }
 
-func (mp *PriorityNonceMempool) GetLowestPriority() int64 {
+func (mp *PriorityNonceMempool) getLowestPriority() int64 {
 	if mp.priorityIndex.Len() == 0 {
 		return 0
 	}
@@ -645,12 +685,13 @@ func IsEmpty(mempool mempool.Mempool) error {
 	return nil
 }
 
-type txInfo struct {
-	sender string
-	nonce  uint64
+type TxInfo struct {
+	Sender string
+	Nonce  uint64
+	Tx     sdk.Tx
 }
 
-func extractTxInfo(tx sdk.Tx) (*txInfo, error) {
+func extractTxInfo(tx sdk.Tx) (*TxInfo, error) {
 	var sender string
 	var nonce uint64
 
@@ -682,5 +723,5 @@ func extractTxInfo(tx sdk.Tx) (*txInfo, error) {
 		nonce = sig.Sequence
 	}
 
-	return &txInfo{sender: sender, nonce: nonce}, nil
+	return &TxInfo{Sender: sender, Nonce: nonce, Tx: tx}, nil
 }