diff --git a/packages/block/db.go b/packages/block/db.go index 497df81a..7864c168 100644 --- a/packages/block/db.go +++ b/packages/block/db.go @@ -62,7 +62,7 @@ func GetRollbacksHashWithDiffArr(dbTx *sqldb.DbTransaction, bId int64) ([]byte, if err != nil { return nil, err } - arr := make([]string, 0, len(rollbackTxs)) + arr := make([]string, 0) for _, row := range rollbackTxs { data, err := json.Marshal(row) if err != nil { @@ -70,6 +70,17 @@ func GetRollbacksHashWithDiffArr(dbTx *sqldb.DbTransaction, bId int64) ([]byte, } arr = append(arr, crypto.HashHex(data)) } + spentInfos, err := sqldb.GetBlockOutputs(dbTx, bId) + if err != nil { + return nil, err + } + for _, row := range spentInfos { + data, err := json.Marshal(row) + if err != nil { + continue + } + arr = append(arr, crypto.HashHex(data)) + } sort.Strings(arr) marshal, _ := json.Marshal(arr) return crypto.Hash(marshal), nil diff --git a/packages/rollback/block.go b/packages/rollback/block.go index d47a5724..43736e27 100644 --- a/packages/rollback/block.go +++ b/packages/rollback/block.go @@ -126,6 +126,12 @@ func rollbackBlock(dbTx *sqldb.DbTransaction, block *block.Block) error { return err } } + + err := sqldb.RollbackOutputs(block.Header.BlockId, dbTx, logger) + if err != nil { + logger.WithFields(log.Fields{"type": consts.DBError, "error": err}).Error("updating outputs by block id") + return err + } return nil } diff --git a/packages/storage/sqldb/spent_info.go b/packages/storage/sqldb/spent_info.go index f0647127..7371996a 100644 --- a/packages/storage/sqldb/spent_info.go +++ b/packages/storage/sqldb/spent_info.go @@ -6,6 +6,8 @@ package sqldb import ( + "github.com/IBAX-io/go-ibax/packages/consts" + log "github.com/sirupsen/logrus" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -75,3 +77,25 @@ func GetTxOutputs(db *DbTransaction, keyIds []int64) ([]SpentInfo, error) { } return result, nil } + +func RollbackOutputs(blockID int64, db *DbTransaction, logger *log.Entry) error { + err := GetDB(db).Exec(`UPDATE spent_info SET input_tx_hash= null , input_index=0 WHERE input_tx_hash in ( SELECT output_tx_hash FROM "spent_info" WHERE block_id = ? )`, blockID).Error + if err != nil { + logger.WithFields(log.Fields{"type": consts.DBError, "error": err}).Errorf("updating input_tx_hash rollback outputs by blockID : %d", blockID) + return err + } + + err = GetDB(db).Exec(`DELETE FROM spent_info WHERE block_id = ? `, blockID).Error + if err != nil { + logger.WithFields(log.Fields{"type": consts.DBError, "error": err}).Errorf("deleting rollback outputs by blockID : %d", blockID) + return err + } + + return nil +} + +func GetBlockOutputs(dbTx *DbTransaction, blockID int64) ([]SpentInfo, error) { + var result []SpentInfo + err := GetDB(dbTx).Where("block_id = ?", blockID).Find(&result).Error + return result, err +}