From 4442eee9f219144ae687e681c02a02406f17a6d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20Junior?= Date: Wed, 14 Dec 2022 17:09:11 -0400 Subject: [PATCH] feat(dot/state): create `Range` to traverse the blocktree and the blocks in the disk (#2990) --- dot/state/block.go | 136 +++++++++++- dot/state/block_test.go | 361 +++++++++++++++++++++++++++++++ dot/state/mocks_chaindb_test.go | 193 +++++++++++++++++ dot/state/mocks_generate_test.go | 1 + dot/state/test_helpers.go | 2 +- lib/blocktree/blocktree.go | 127 ++++++++--- lib/blocktree/blocktree_test.go | 55 +++-- lib/blocktree/node.go | 23 -- 8 files changed, 813 insertions(+), 85 deletions(-) create mode 100644 dot/state/mocks_chaindb_test.go diff --git a/dot/state/block.go b/dot/state/block.go index 4eeae7e75b..b57774f13b 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -38,6 +38,7 @@ var ( messageQueuePrefix = []byte("mqp") // messageQueuePrefix + hash -> message queue justificationPrefix = []byte("jcp") // justificationPrefix + hash -> justification + errNilBlockTree = errors.New("blocktree is nil") errNilBlockBody = errors.New("block body is nil") syncedBlocksGauge = promauto.NewGauge(prometheus.GaugeOpts{ @@ -542,10 +543,141 @@ func (bs *BlockState) GetSlotForBlock(hash common.Hash) (uint64, error) { return types.GetSlotFromHeader(header) } +var ErrEmptyHeader = errors.New("empty header") + +func (bs *BlockState) loadHeaderFromDatabase(hash common.Hash) (header *types.Header, err error) { + startHeaderData, err := bs.db.Get(headerKey(hash)) + if err != nil { + return nil, fmt.Errorf("querying database: %w", err) + } + + header = types.NewEmptyHeader() + err = scale.Unmarshal(startHeaderData, header) + if err != nil { + return nil, fmt.Errorf("unmarshaling start header: %w", err) + } + + if header.Empty() { + return nil, fmt.Errorf("%w: %s", ErrEmptyHeader, hash) + } + + return header, nil +} + +// Range returns the sub-blockchain between the starting hash and the +// ending hash using both block tree and database +func (bs *BlockState) Range(startHash, endHash common.Hash) (hashes []common.Hash, err error) { + if startHash == endHash { + hashes = []common.Hash{startHash} + return hashes, nil + } + + endHeader, err := bs.loadHeaderFromDatabase(endHash) + if errors.Is(err, chaindb.ErrKeyNotFound) || + errors.Is(err, ErrEmptyHeader) { + // end hash is not in the database so we should lookup the + // block that could be in memory and in the database as well + return bs.retrieveRange(startHash, endHash) + } else if err != nil { + return nil, fmt.Errorf("retrieving end hash from database: %w", err) + } + + // end hash was found in the database, that means all the blocks + // between start and end can be found in the database + return bs.retrieveRangeFromDatabase(startHash, endHeader) +} + +func (bs *BlockState) retrieveRange(startHash, endHash common.Hash) (hashes []common.Hash, err error) { + inMemoryHashes, err := bs.bt.Range(startHash, endHash) + if err != nil { + return nil, fmt.Errorf("retrieving range from in-memory blocktree: %w", err) + } + + firstItem := inMemoryHashes[0] + + // if the first item is equal to the startHash that means we got the range + // from the in-memory blocktree + if firstItem == startHash { + return inMemoryHashes, nil + } + + // since we got as many blocks as we could from + // the block tree but still missing blocks to + // fulfil the range we should lookup in the + // database for the remaining ones, the first item in the hashes array + // must be the block tree root that is also placed in the database + // so we will start from its parent since it is already in the array + blockTreeRootHeader, err := bs.loadHeaderFromDatabase(firstItem) + if err != nil { + return nil, fmt.Errorf("loading block tree root from database: %w", err) + } + + startingAtParentHeader, err := bs.loadHeaderFromDatabase(blockTreeRootHeader.ParentHash) + if err != nil { + return nil, fmt.Errorf("loading header of parent of the root from database: %w", err) + } + + inDatabaseHashes, err := bs.retrieveRangeFromDatabase(startHash, startingAtParentHeader) + if err != nil { + return nil, fmt.Errorf("retrieving range from database: %w", err) + } + + hashes = append(inDatabaseHashes, inMemoryHashes...) + return hashes, nil +} + +var ErrStartHashMismatch = errors.New("start hash mismatch") +var ErrStartGreaterThanEnd = errors.New("start greater than end") + +// retrieveRangeFromDatabase takes the start and the end and will retrieve all block in between +// where all blocks (start and end inclusive) are supposed to be placed at database +func (bs *BlockState) retrieveRangeFromDatabase(startHash common.Hash, + endHeader *types.Header) (hashes []common.Hash, err error) { + startHeader, err := bs.loadHeaderFromDatabase(startHash) + if err != nil { + return nil, fmt.Errorf("range start should be in database: %w", err) + } + + if startHeader.Number > endHeader.Number { + return nil, fmt.Errorf("%w", ErrStartGreaterThanEnd) + } + + // blocksInRange is the difference between the end number to start number + // but the difference doesn't include the start item so we add 1 + blocksInRange := endHeader.Number - startHeader.Number + 1 + + hashes = make([]common.Hash, blocksInRange) + + lastPosition := blocksInRange - 1 + + hashes[0] = startHash + hashes[lastPosition] = endHeader.Hash() + + inLoopHash := endHeader.ParentHash + for currentPosition := lastPosition - 1; currentPosition > 0; currentPosition-- { + hashes[currentPosition] = inLoopHash + + inLoopHeader, err := bs.loadHeaderFromDatabase(inLoopHash) + if err != nil { + return nil, fmt.Errorf("retrieving hash %s from database: %w", inLoopHash.Short(), err) + } + + inLoopHash = inLoopHeader.ParentHash + } + + // here we ensure that we finished up the loop + // with the same hash as the startHash + if inLoopHash != startHash { + return nil, fmt.Errorf("%w: expecting %s, found: %s", ErrStartHashMismatch, startHash.Short(), inLoopHash.Short()) + } + + return hashes, nil +} + // SubChain returns the sub-blockchain between the starting hash and the ending hash using the block tree func (bs *BlockState) SubChain(start, end common.Hash) ([]common.Hash, error) { if bs.bt == nil { - return nil, fmt.Errorf("blocktree is nil") + return nil, fmt.Errorf("%w", errNilBlockTree) } return bs.bt.SubBlockchain(start, end) @@ -555,7 +687,7 @@ func (bs *BlockState) SubChain(start, end common.Hash) ([]common.Hash, error) { // it returns an error if parent or child are not in the blocktree. func (bs *BlockState) IsDescendantOf(parent, child common.Hash) (bool, error) { if bs.bt == nil { - return false, fmt.Errorf("blocktree is nil") + return false, fmt.Errorf("%w", errNilBlockTree) } return bs.bt.IsDescendantOf(parent, child) diff --git a/dot/state/block_test.go b/dot/state/block_test.go index e788427957..1764183f85 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -4,10 +4,13 @@ package state import ( + "errors" "testing" "time" + "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" @@ -599,3 +602,361 @@ func TestNumberIsFinalised(t *testing.T) { require.NoError(t, err) require.False(t, fin) } + +func TestRange(t *testing.T) { + t.Parallel() + + loadHeaderFromDiskErr := errors.New("[mocked] cannot read, database closed ex") + testcases := map[string]struct { + blocksToCreate int + blocksToPersistAtDisk int + + newBlockState func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState + wantErr error + stringErr string + + expectedHashes func(hashesCreated []common.Hash) (expected []common.Hash) + executeRangeCall func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) + }{ + "all_blocks_stored_in_disk": { + blocksToCreate: 128, + blocksToPersistAtDisk: 128, + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return hashesCreated + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[0] + endHash := hashesCreated[len(hashesCreated)-1] + + return blockState.Range(startHash, endHash) + }, + }, + + "all_blocks_persisted_in_blocktree": { + blocksToCreate: 128, + blocksToPersistAtDisk: 0, + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return hashesCreated + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[0] + endHash := hashesCreated[len(hashesCreated)-1] + + return blockState.Range(startHash, endHash) + }, + }, + + "half_blocks_placed_in_blocktree_half_stored_in_disk": { + blocksToCreate: 128, + blocksToPersistAtDisk: 64, + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return hashesCreated + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[0] + endHash := hashesCreated[len(hashesCreated)-1] + + return blockState.Range(startHash, endHash) + }, + }, + + "error_while_loading_header_from_disk": { + blocksToCreate: 2, + blocksToPersistAtDisk: 0, + wantErr: loadHeaderFromDiskErr, + stringErr: "retrieving end hash from database: " + + "querying database: [mocked] cannot read, database closed ex", + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + + mockedDb := NewMockDatabase(ctrl) + // cannot assert the exact hash type since the block header + // hash is generate by the running test case + mockedDb.EXPECT().Get(gomock.AssignableToTypeOf([]byte{})). + Return(nil, loadHeaderFromDiskErr) + blockState.db = mockedDb + + require.NoError(t, err) + return blockState + }, + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return nil + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[0] + endHash := hashesCreated[len(hashesCreated)-1] + + return blockState.Range(startHash, endHash) + }, + }, + + "using_same_hash_as_parameters": { + blocksToCreate: 128, + blocksToPersistAtDisk: 0, + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return []common.Hash{hashesCreated[0]} + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[0] + endHash := hashesCreated[0] + + return blockState.Range(startHash, endHash) + }, + }, + + "start_hash_greater_than_end_hash_in_database": { + blocksToCreate: 128, + blocksToPersistAtDisk: 128, + wantErr: ErrStartGreaterThanEnd, + stringErr: "start greater than end", + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return nil + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[10] + endHash := hashesCreated[0] + + return blockState.Range(startHash, endHash) + }, + }, + + "start_hash_greater_than_end_hash_in_memory": { + blocksToCreate: 128, + blocksToPersistAtDisk: 0, + wantErr: blocktree.ErrStartGreaterThanEnd, + stringErr: "retrieving range from in-memory blocktree: " + + "getting blocks in range: start greater than end", + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return nil + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[10] + endHash := hashesCreated[0] + + return blockState.Range(startHash, endHash) + }, + }, + + "start_hash_in_memory_while_end_hash_in_database": { + blocksToCreate: 128, + blocksToPersistAtDisk: 64, + wantErr: chaindb.ErrKeyNotFound, + stringErr: "range start should be in database: " + + "querying database: Key not found", + newBlockState: func(t *testing.T, ctrl *gomock.Controller, + genesisHeader *types.Header) *BlockState { + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + return blockState + }, + + // execute the Range call. All the values returned must + // match the hashes we previsouly created + expectedHashes: func(hashesCreated []common.Hash) (expected []common.Hash) { + return nil + }, + executeRangeCall: func(blockState *BlockState, + hashesCreated []common.Hash) (retrievedHashes []common.Hash, err error) { + startHash := hashesCreated[len(hashesCreated)-1] + // since we finalized 64 of 128 blocks the end hash is one of + // those blocks persisted at database, while start hash is + // one of those blocks that keeps in memory + endHash := hashesCreated[0] + + return blockState.Range(startHash, endHash) + }, + }, + } + + for tname, tt := range testcases { + tt := tt + + t.Run(tname, func(t *testing.T) { + t.Parallel() + + require.LessOrEqualf(t, tt.blocksToPersistAtDisk, tt.blocksToCreate, + "blocksToPersistAtDisk should be lower or equal blocksToCreate") + + ctrl := gomock.NewController(t) + genesisHeader := &types.Header{ + Number: 0, + StateRoot: trie.EmptyHash, + Digest: types.NewDigest(), + } + + blockState := tt.newBlockState(t, ctrl, genesisHeader) + + testBlockBody := *types.NewBody([]types.Extrinsic{[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}) + hashesCreated := make([]common.Hash, 0, tt.blocksToCreate) + previousHeaderHash := genesisHeader.Hash() + for blockNumber := 1; blockNumber <= tt.blocksToCreate; blockNumber++ { + currentHeader := &types.Header{ + Number: uint(blockNumber), + Digest: createPrimaryBABEDigest(t), + ParentHash: previousHeaderHash, + } + + block := &types.Block{ + Header: *currentHeader, + Body: testBlockBody, + } + + err := blockState.AddBlock(block) + require.NoError(t, err) + + hashesCreated = append(hashesCreated, currentHeader.Hash()) + previousHeaderHash = currentHeader.Hash() + } + + if tt.blocksToPersistAtDisk > 0 { + hashIndexToSetAsFinalized := tt.blocksToPersistAtDisk - 1 + selectedHash := hashesCreated[hashIndexToSetAsFinalized] + + err := blockState.SetFinalisedHash(selectedHash, 0, 0) + require.NoError(t, err) + } + + expectedHashes := tt.expectedHashes(hashesCreated) + retrievedHashes, err := tt.executeRangeCall(blockState, hashesCreated) + require.ErrorIs(t, err, tt.wantErr) + if tt.stringErr != "" { + require.EqualError(t, err, tt.stringErr) + } + + require.Equal(t, expectedHashes, retrievedHashes) + }) + } +} + +func Test_loadHeaderFromDisk_WithGenesisBlock(t *testing.T) { + ctrl := gomock.NewController(t) + + telemetryMock := NewMockClient(ctrl) + telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + db := NewInMemoryDB(t) + + genesisHeader := &types.Header{ + Number: 0, + StateRoot: trie.EmptyHash, + Digest: types.NewDigest(), + } + + blockState, err := NewBlockStateFromGenesis(db, newTriesEmpty(), genesisHeader, telemetryMock) + require.NoError(t, err) + + header, err := blockState.loadHeaderFromDatabase(genesisHeader.Hash()) + require.NoError(t, err) + require.Equal(t, genesisHeader.Hash(), header.Hash()) +} diff --git a/dot/state/mocks_chaindb_test.go b/dot/state/mocks_chaindb_test.go new file mode 100644 index 0000000000..a25cdb39f6 --- /dev/null +++ b/dot/state/mocks_chaindb_test.go @@ -0,0 +1,193 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/chaindb (interfaces: Database) + +// Package state is a generated GoMock package. +package state + +import ( + context "context" + reflect "reflect" + + chaindb "github.com/ChainSafe/chaindb" + pb "github.com/dgraph-io/badger/v2/pb" + gomock "github.com/golang/mock/gomock" +) + +// MockDatabase is a mock of Database interface. +type MockDatabase struct { + ctrl *gomock.Controller + recorder *MockDatabaseMockRecorder +} + +// MockDatabaseMockRecorder is the mock recorder for MockDatabase. +type MockDatabaseMockRecorder struct { + mock *MockDatabase +} + +// NewMockDatabase creates a new mock instance. +func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { + mock := &MockDatabase{ctrl: ctrl} + mock.recorder = &MockDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { + return m.recorder +} + +// ClearAll mocks base method. +func (m *MockDatabase) ClearAll() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClearAll") + ret0, _ := ret[0].(error) + return ret0 +} + +// ClearAll indicates an expected call of ClearAll. +func (mr *MockDatabaseMockRecorder) ClearAll() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearAll", reflect.TypeOf((*MockDatabase)(nil).ClearAll)) +} + +// Close mocks base method. +func (m *MockDatabase) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockDatabaseMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close)) +} + +// Del mocks base method. +func (m *MockDatabase) Del(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Del", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Del indicates an expected call of Del. +func (mr *MockDatabaseMockRecorder) Del(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockDatabase)(nil).Del), arg0) +} + +// Flush mocks base method. +func (m *MockDatabase) Flush() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +// Flush indicates an expected call of Flush. +func (mr *MockDatabaseMockRecorder) Flush() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockDatabase)(nil).Flush)) +} + +// Get mocks base method. +func (m *MockDatabase) Get(arg0 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockDatabaseMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDatabase)(nil).Get), arg0) +} + +// Has mocks base method. +func (m *MockDatabase) Has(arg0 []byte) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Has", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Has indicates an expected call of Has. +func (mr *MockDatabaseMockRecorder) Has(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockDatabase)(nil).Has), arg0) +} + +// NewBatch mocks base method. +func (m *MockDatabase) NewBatch() chaindb.Batch { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewBatch") + ret0, _ := ret[0].(chaindb.Batch) + return ret0 +} + +// NewBatch indicates an expected call of NewBatch. +func (mr *MockDatabaseMockRecorder) NewBatch() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewBatch", reflect.TypeOf((*MockDatabase)(nil).NewBatch)) +} + +// NewIterator mocks base method. +func (m *MockDatabase) NewIterator() chaindb.Iterator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewIterator") + ret0, _ := ret[0].(chaindb.Iterator) + return ret0 +} + +// NewIterator indicates an expected call of NewIterator. +func (mr *MockDatabaseMockRecorder) NewIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIterator", reflect.TypeOf((*MockDatabase)(nil).NewIterator)) +} + +// Path mocks base method. +func (m *MockDatabase) Path() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Path") + ret0, _ := ret[0].(string) + return ret0 +} + +// Path indicates an expected call of Path. +func (mr *MockDatabaseMockRecorder) Path() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Path", reflect.TypeOf((*MockDatabase)(nil).Path)) +} + +// Put mocks base method. +func (m *MockDatabase) Put(arg0, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put. +func (mr *MockDatabaseMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockDatabase)(nil).Put), arg0, arg1) +} + +// Subscribe mocks base method. +func (m *MockDatabase) Subscribe(arg0 context.Context, arg1 func(*pb.KVList) error, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Subscribe", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Subscribe indicates an expected call of Subscribe. +func (mr *MockDatabaseMockRecorder) Subscribe(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockDatabase)(nil).Subscribe), arg0, arg1, arg2) +} diff --git a/dot/state/mocks_generate_test.go b/dot/state/mocks_generate_test.go index bde3f7cf99..c4b130a82f 100644 --- a/dot/state/mocks_generate_test.go +++ b/dot/state/mocks_generate_test.go @@ -4,3 +4,4 @@ package state //go:generate mockgen -destination=mock_telemetry_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/dot/telemetry Client +//go:generate mockgen -destination=mocks_chaindb_test.go -package $GOPACKAGE github.com/ChainSafe/chaindb Database diff --git a/dot/state/test_helpers.go b/dot/state/test_helpers.go index 4066343e05..6dfed8f905 100644 --- a/dot/state/test_helpers.go +++ b/dot/state/test_helpers.go @@ -36,7 +36,7 @@ func NewInMemoryDB(t *testing.T) chaindb.Database { return db } -func createPrimaryBABEDigest(t *testing.T) scale.VaryingDataTypeSlice { +func createPrimaryBABEDigest(t testing.TB) scale.VaryingDataTypeSlice { babeDigest := types.NewBabeDigest() err := babeDigest.Set(types.BabePrimaryPreDigest{AuthorityIndex: 0}) require.NoError(t, err) diff --git a/lib/blocktree/blocktree.go b/lib/blocktree/blocktree.go index c0b683f0a5..a9ae5443a2 100644 --- a/lib/blocktree/blocktree.go +++ b/lib/blocktree/blocktree.go @@ -130,6 +130,103 @@ func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) (hashes []common.Has return bt.root.getNodesWithNumber(number, hashes) } +var ErrStartGreaterThanEnd = errors.New("start greater than end") +var ErrNilBlockInRange = errors.New("nil block in range") + +// Range will return all the blocks between the start and +// end hash inclusive. +// If the end hash does not exist in the blocktree then an error +// is be returned. +// If the start hash does not exist in the blocktree +// then we will return all blocks between the end and the blocktree +// root inclusive +func (bt *BlockTree) Range(startHash common.Hash, endHash common.Hash) (hashes []common.Hash, err error) { + bt.Lock() + defer bt.Unlock() + + endNode := bt.getNode(endHash) + if endNode == nil { + return nil, fmt.Errorf("%w: %s", ErrEndNodeNotFound, endHash) + } + + // if we don't find the start hash in the blocktree + // that means it should be in the database, so we retrieve + // as many nodes as we can, in other words we get all the + // blocks from the end hash till the bt.root inclusive + startNode := bt.getNode(startHash) + if startNode == nil { + startNode = bt.root + } + + hashes, err = accumulateHashesInDescedingOrder(endNode, startNode) + if err != nil { + return nil, fmt.Errorf("getting blocks in range: %w", err) + } + + return hashes, nil +} + +// SubBlockchain returns the path from the node with Hash start to the node with Hash end +func (bt *BlockTree) SubBlockchain(startHash common.Hash, endHash common.Hash) (hashes []common.Hash, err error) { + bt.Lock() + defer bt.Unlock() + + endNode := bt.getNode(endHash) + if endNode == nil { + return nil, fmt.Errorf("%w: %s", ErrEndNodeNotFound, endHash) + } + + // if we don't find the start hash in the blocktree + // that means it should be in the database, so we retrieve + // as many nodes as we can, in other words we get all the + // blocks from the end hash till the bt.root inclusive + startNode := bt.getNode(startHash) + if startNode == nil { + return nil, fmt.Errorf("%w: %s", ErrStartNodeNotFound, endHash) + } + + if startNode.number > endNode.number { + return nil, fmt.Errorf("%w", ErrStartGreaterThanEnd) + } + + hashes, err = accumulateHashesInDescedingOrder(endNode, startNode) + if err != nil { + return nil, fmt.Errorf("getting blocks in range: %w", err) + } + + return hashes, nil +} + +func accumulateHashesInDescedingOrder(endNode, startNode *node) ( + hashes []common.Hash, err error) { + + if startNode.number > endNode.number { + return nil, fmt.Errorf("%w", ErrStartGreaterThanEnd) + } + + // blocksInRange is the difference between the end number to start number + // but the difference don't includes the start item that is why we add 1 + blocksInRange := endNode.number - startNode.number + 1 + hashes = make([]common.Hash, blocksInRange) + + lastPosition := blocksInRange - 1 + hashes[0] = startNode.hash + + for position := lastPosition; position > 0; position-- { + currentNodeHash := endNode.hash + hashes[position] = currentNodeHash + + endNode = endNode.parent + + if endNode == nil { + return nil, fmt.Errorf("%w: missing parent of %s", + ErrNilBlockInRange, currentNodeHash) + } + } + + return hashes, nil +} + // getNode finds and returns a node based on its Hash. Returns nil if not found. func (bt *BlockTree) getNode(h Hash) (ret *node) { if bt.root.hash == h { @@ -210,36 +307,6 @@ func (bt *BlockTree) String() string { return fmt.Sprintf("%s\n%s\n", metadata, tree.Print()) } -// subChain returns the path from the node with Hash start to the node with Hash end -func (bt *BlockTree) subChain(start, end Hash) ([]*node, error) { - sn := bt.getNode(start) - if sn == nil { - return nil, fmt.Errorf("%w: %s", ErrStartNodeNotFound, start) - } - en := bt.getNode(end) - if en == nil { - return nil, fmt.Errorf("%w: %s", ErrEndNodeNotFound, end) - } - return sn.subChain(en) -} - -// SubBlockchain returns the path from the node with Hash start to the node with Hash end -func (bt *BlockTree) SubBlockchain(start, end Hash) ([]Hash, error) { - bt.RLock() - defer bt.RUnlock() - - sc, err := bt.subChain(start, end) - if err != nil { - return nil, err - } - var bc []Hash - for _, node := range sc { - bc = append(bc, node.hash) - } - return bc, nil - -} - // best returns the best node in the block tree using the fork choice rule. func (bt *BlockTree) best() *node { return bt.leaves.bestBlock() diff --git a/lib/blocktree/blocktree_test.go b/lib/blocktree/blocktree_test.go index f25be6adb0..0acf2637cf 100644 --- a/lib/blocktree/blocktree_test.go +++ b/lib/blocktree/blocktree_test.go @@ -37,7 +37,7 @@ func newBlockTreeFromNode(root *node) *BlockTree { } } -func createPrimaryBABEDigest(t *testing.T) scale.VaryingDataTypeSlice { +func createPrimaryBABEDigest(t testing.TB) scale.VaryingDataTypeSlice { babeDigest := types.NewBabeDigest() err := babeDigest.Set(types.BabePrimaryPreDigest{AuthorityIndex: 0}) require.NoError(t, err) @@ -116,7 +116,7 @@ func createTestBlockTree(t *testing.T, header *types.Header, number uint) (*Bloc return bt, branches } -func createFlatTree(t *testing.T, number uint) (*BlockTree, []common.Hash) { +func createFlatTree(t testing.TB, number uint) (*BlockTree, []common.Hash) { rootHeader := &types.Header{ ParentHash: zeroHash, Digest: createPrimaryBABEDigest(t), @@ -219,33 +219,6 @@ func Test_Node_isDecendantOf(t *testing.T) { } } -func Test_BlockTree_Subchain(t *testing.T) { - bt, hashes := createFlatTree(t, 4) - expectedPath := hashes[1:] - - // Insert a block to create a competing path - extraBlock := &types.Header{ - ParentHash: hashes[0], - Number: 1, - Digest: createPrimaryBABEDigest(t), - } - - extraBlock.Hash() - err := bt.AddBlock(extraBlock, time.Unix(0, 0)) - require.NotNil(t, err) - - subChain, err := bt.subChain(hashes[1], hashes[3]) - if err != nil { - t.Fatal(err) - } - - for i, n := range subChain { - if n.hash != expectedPath[i] { - t.Errorf("expected Hash: 0x%X got: 0x%X\n", expectedPath[i], n.hash) - } - } -} - func Test_BlockTree_Best_AllPrimary(t *testing.T) { arrivalTime := int64(256) var expected Hash @@ -823,3 +796,27 @@ func Test_BlockTree_best(t *testing.T) { bt.leaves.store(bt.root.children[2].hash, bt.root.children[2]) require.Equal(t, bt.root.children[2].hash, bt.BestBlockHash()) } + +func BenchmarkBlockTreeSubBlockchain(b *testing.B) { + testInputs := []struct { + input int + }{ + {input: 100}, + {input: 1000}, + {input: 10000}, + } + + for _, tt := range testInputs { + bt, expectedHashes := createFlatTree(b, uint(tt.input)) + + firstHash := expectedHashes[0] + endHash := expectedHashes[len(expectedHashes)-1] + + b.Run(fmt.Sprintf("input_len_%d", tt.input), func(b *testing.B) { + for i := 0; i < b.N; i++ { + bt.SubBlockchain(firstHash, endHash) + } + }) + } + +} diff --git a/lib/blocktree/node.go b/lib/blocktree/node.go index cacc940197..557f805066 100644 --- a/lib/blocktree/node.go +++ b/lib/blocktree/node.go @@ -78,29 +78,6 @@ func (n *node) getNodesWithNumber(number uint, hashes []common.Hash) []common.Ha return hashes } -// subChain searches for a chain with head n and descendant going from child -> parent -func (n *node) subChain(descendant *node) ([]*node, error) { - if descendant == nil { - return nil, ErrNilDescendant - } - - var path []*node - - if n.hash == descendant.hash { - path = append(path, n) - return path, nil - } - - for curr := descendant; curr != nil; curr = curr.parent { - path = append([]*node{curr}, path...) - if curr.hash == n.hash { - return path, nil - } - } - - return nil, ErrDescendantNotFound -} - // isDescendantOf traverses the tree following all possible paths until it determines if n is a descendant of parent func (n *node) isDescendantOf(parent *node) bool { if parent == nil || n == nil {