diff --git a/testutil/datagen/btc_header_tree.go b/testutil/datagen/btc_header_tree.go index 4baf575a0..24f8cfe6a 100644 --- a/testutil/datagen/btc_header_tree.go +++ b/testutil/datagen/btc_header_tree.go @@ -234,6 +234,16 @@ func (t *BTCHeaderTree) RandomDescendant(node *blctypes.BTCHeaderInfo) *blctypes return descendants[idx] } +// GetHeadersMap returns a mapping between node hashes and nodes +func (t *BTCHeaderTree) GetHeadersMap() map[string]*blctypes.BTCHeaderInfo { + return t.headers +} + +// Size returns the number of nodes that are maintained +func (t *BTCHeaderTree) Size() int { + return len(t.headers) +} + // getParent returns the parent of the node, or nil if it doesn't exist func (t *BTCHeaderTree) getParent(node *blctypes.BTCHeaderInfo) *blctypes.BTCHeaderInfo { if header, ok := t.headers[node.Header.ParentHash().String()]; ok { diff --git a/types/utils.go b/types/utils.go new file mode 100644 index 000000000..ac232d60e --- /dev/null +++ b/types/utils.go @@ -0,0 +1,11 @@ +package types + +import "reflect" + +func Reverse(s interface{}) { + n := reflect.ValueOf(s).Len() + swap := reflect.Swapper(s) + for i, j := 0, n-1; i < j; i, j = i+1, j-1 { + swap(i, j) + } +} diff --git a/x/btclightclient/keeper/grpc_query.go b/x/btclightclient/keeper/grpc_query.go index 2bc3422ae..51576136b 100644 --- a/x/btclightclient/keeper/grpc_query.go +++ b/x/btclightclient/keeper/grpc_query.go @@ -4,7 +4,6 @@ import ( "context" bbl "github.com/babylonchain/babylon/types" "github.com/babylonchain/babylon/x/btclightclient/types" - "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" @@ -38,7 +37,7 @@ func (k Keeper) Hashes(ctx context.Context, req *types.QueryHashesRequest) (*typ } } - store := prefix.NewStore(k.headersState(sdkCtx).hashToHeight, types.HashToHeightPrefix) + store := k.headersState(sdkCtx).hashToHeight pageRes, err := query.FilteredPaginate(store, 
req.Pagination, func(key []byte, _ []byte, accumulate bool) (bool, error) { if accumulate { hashes = append(hashes, key) @@ -72,47 +71,80 @@ func (k Keeper) MainChain(ctx context.Context, req *types.QueryMainChainRequest) if req.Pagination == nil { req.Pagination = &query.PageRequest{} } - // If a starting key has not been set, then the first header is the tip - prevHeader := k.headersState(sdkCtx).GetTip() - // Otherwise, retrieve the header from the key + + if req.Pagination.Limit == 0 { + req.Pagination.Limit = query.DefaultLimit + } + + var keyHeader *types.BTCHeaderInfo if len(req.Pagination.Key) != 0 { headerHash, err := bbl.NewBTCHeaderHashBytesFromBytes(req.Pagination.Key) if err != nil { return nil, status.Error(codes.InvalidArgument, "key does not correspond to a header hash") } - prevHeader, err = k.headersState(sdkCtx).GetHeaderByHash(&headerHash) - } - - // If no tip exists or a key, then return an empty response - if prevHeader == nil { - return &types.QueryMainChainResponse{}, nil + keyHeader, err = k.headersState(sdkCtx).GetHeaderByHash(&headerHash) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "header specified by key does not exist") + } } var headers []*types.BTCHeaderInfo - headers = append(headers, prevHeader) - store := prefix.NewStore(k.headersState(sdkCtx).headers, types.HeadersObjectPrefix) + var nextKey []byte + if req.Pagination.Reverse { + var start, end uint64 + baseHeader := k.headersState(sdkCtx).GetBaseBTCHeader() + // The base header is located at the end of the mainchain + // which requires starting at the end + mainchain := k.headersState(sdkCtx).GetMainChain() + // Reverse the mainchain -- we want to retrieve results starting from the base header + bbl.Reverse(mainchain) + if keyHeader == nil { + keyHeader = baseHeader + start = 0 + } else { + start = keyHeader.Height - baseHeader.Height + } + end = start + req.Pagination.Limit - // Set this value to true to signal to FilteredPaginate to iterate the entries 
in reverse - req.Pagination.Reverse = true - pageRes, err := query.FilteredPaginate(store, req.Pagination, func(_ []byte, value []byte, accumulate bool) (bool, error) { - if accumulate { - headerInfo := headerInfoFromStoredBytes(k.cdc, value) - // If the previous block extends this block, then this block is part of the main chain - if prevHeader.HasParent(headerInfo) { - prevHeader = headerInfo - headers = append(headers, headerInfo) - } + if end >= uint64(len(mainchain)) { + end = uint64(len(mainchain)) } - return true, nil - }) - if err != nil { - return nil, err - } + // If the header's position on the mainchain is larger than the entire mainchain, then it is not part of the mainchain + // Also, if the element at the header's position on the mainchain is not the provided one, then it is not part of the mainchain + if start >= uint64(len(mainchain)) || !mainchain[start].Eq(keyHeader) { + return nil, status.Error(codes.InvalidArgument, "header specified by key is not a part of the mainchain") + } + headers = mainchain[start:end] + if end < uint64(len(mainchain)) { + nextKey = mainchain[end].Hash.MustMarshal() + } + } else { + tip := k.headersState(sdkCtx).GetTip() + // If there is no starting key, then the starting header is the tip + if keyHeader == nil { + keyHeader = tip + } + // This is the depth in which the start header should in the mainchain + startHeaderDepth := tip.Height - keyHeader.Height + // The depth that we want to retrieve up to + // -1 because the depth denotes how many headers have been built on top of it + depth := startHeaderDepth + req.Pagination.Limit - 1 + // Retrieve the mainchain up to the depth + mainchain := k.headersState(sdkCtx).GetMainChainUpTo(depth) + // Check whether the key provided is part of the mainchain + if uint64(len(mainchain)) <= startHeaderDepth || !mainchain[startHeaderDepth].Eq(keyHeader) { + return nil, status.Error(codes.InvalidArgument, "header specified by key is not a part of the mainchain") + } - // Override the 
next key attribute to point to the parent of the last header - // instead of the next element contained in the store - pageRes.NextKey = prevHeader.Header.ParentHash().MustMarshal() + // The next key is the last element's parent hash + nextKey = mainchain[len(mainchain)-1].Header.ParentHash().MustMarshal() + headers = mainchain[startHeaderDepth:] + } + pageRes := &query.PageResponse{ + NextKey: nextKey, + } + // The headers that we should return start from the depth of the start header return &types.QueryMainChainResponse{Headers: headers, Pagination: pageRes}, nil } diff --git a/x/btclightclient/keeper/grpc_query_test.go b/x/btclightclient/keeper/grpc_query_test.go index b211dde52..5d79a5577 100644 --- a/x/btclightclient/keeper/grpc_query_test.go +++ b/x/btclightclient/keeper/grpc_query_test.go @@ -1,6 +1,10 @@ package keeper_test import ( + "github.com/babylonchain/babylon/testutil/datagen" + bbl "github.com/babylonchain/babylon/types" + "github.com/cosmos/cosmos-sdk/types/query" + "math/rand" "testing" testkeeper "github.com/babylonchain/babylon/testutil/keeper" @@ -19,3 +23,320 @@ func TestParamsQuery(t *testing.T) { require.NoError(t, err) require.Equal(t, &types.QueryParamsResponse{Params: params}, response) } + +func FuzzHashesQuery(f *testing.F) { + /* + Checks: + 1. If the request is nil, an error is returned + 2. If the pagination key has not been set, + `limit` number of hashes are returned and the pagination key + has been set to the next hash. + 3. If the pagination key has been set, + the `limit` number of hashes after the key are returned. + 4. End of pagination: the last hashes are returned properly. + 5. If the pagination key is not a valid hash, an error is returned. + + Data Generation: + - Generate a random tree of headers and insert their hashes + into the hashToHeight storage. + - Generate a random `limit` to the query as an integer between 1 and the + total number of hashes. 
+ Do checks 2-4 by initially querying without a key and then querying + with the nextKey attribute. + */ + datagen.AddRandomSeedsToFuzzer(f, 100) + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + blcKeeper, ctx := testkeeper.BTCLightClientKeeper(t) + sdkCtx := sdk.WrapSDKContext(ctx) + + // Test nil request + resp, err := blcKeeper.Hashes(sdkCtx, nil) + if resp != nil { + t.Errorf("Nil input led to a non-nil response") + } + if err == nil { + t.Errorf("Nil input led to a nil error") + } + + // Test pagination key being invalid + // We want the key to have a positive length + bzSz := datagen.RandomIntOtherThan(bbl.BTCHeaderHashLen-1, bbl.BTCHeaderHashLen*10) + 1 + key := datagen.GenRandomByteArray(bzSz) + pagination := constructRequestWithKey(key) + hashesRequest := types.NewQueryHashesRequest(pagination) + resp, err = blcKeeper.Hashes(sdkCtx, hashesRequest) + if resp != nil { + t.Errorf("Invalid key led to a non-nil response") + } + if err == nil { + t.Errorf("Invalid key led to a nil error") + } + + // Generate a random tree of headers + tree := genRandomTree(blcKeeper, ctx, 1, 10) + // Get the headers map + headersMap := tree.GetHeadersMap() + // Generate a random limit + treeSize := uint64(tree.Size()) + limit := uint64(rand.Int63n(int64(tree.Size())) + 1) + // Generate a page request with a limit and a nil key + pagination = constructRequestWithLimit(limit) + // Generate the initial query + hashesRequest = types.NewQueryHashesRequest(pagination) + // Construct a mapping from the hashes found to a boolean value + // Will be used later to evaluate whether all the hashes were returned + hashesFound := make(map[string]bool, 0) + + for headersRetrieved := uint64(0); headersRetrieved < treeSize; headersRetrieved += limit { + resp, err = blcKeeper.Hashes(sdkCtx, hashesRequest) + if err != nil { + t.Errorf("Valid request led to an error %s", err) + } + if resp == nil { + t.Fatalf("Valid request led to a nil response") + } + // If we are on the last page the 
elements retrieved should be equal to the remaining ones + if headersRetrieved+limit >= treeSize && uint64(len(resp.Hashes)) != treeSize-headersRetrieved { + t.Fatalf("On the last page expected %d elements but got %d", treeSize-headersRetrieved, len(resp.Hashes)) + } + // Otherwise, the elements retrieved should be equal to the limit + if headersRetrieved+limit < treeSize && uint64(len(resp.Hashes)) != limit { + t.Fatalf("On an intermediate page expected %d elements but got %d", limit, len(resp.Hashes)) + } + + for _, hash := range resp.Hashes { + // Check if the hash was generated by the tree + if _, ok := headersMap[hash.String()]; !ok { + t.Fatalf("Hashes returned a hash that was not created") + } + hashesFound[hash.String()] = true + } + + // Construct the next page request + pagination = constructRequestWithKeyAndLimit(resp.Pagination.NextKey, limit) + hashesRequest = types.NewQueryHashesRequest(pagination) + } + + if len(hashesFound) != len(headersMap) { + t.Errorf("Some hashes were missed. Got %d while %d were expected", len(hashesFound), len(headersMap)) + } + }) +} + +func FuzzContainsQuery(f *testing.F) { + /* + Checks: + 1. If the request is nil, (nil, error) is returned + 2. The query returns true or false depending on whether the hash exists. + + Data generation: + - Generate a random tree of headers and insert into storage. + - Generate a random header but do not insert it into storage. 
+ */ + datagen.AddRandomSeedsToFuzzer(f, 100) + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + blcKeeper, ctx := testkeeper.BTCLightClientKeeper(t) + sdkCtx := sdk.WrapSDKContext(ctx) + + // Test nil input + resp, err := blcKeeper.Contains(sdkCtx, nil) + if resp != nil { + t.Errorf("Nil input led to a non-nil response") + } + if err == nil { + t.Errorf("Nil input led to a nil error") + } + + // Generate a random tree of headers and insert it into storage + tree := genRandomTree(blcKeeper, ctx, 1, 10) + + // Test with a non-existent header + query, _ := types.NewQueryContainsRequest(datagen.GenRandomBTCHeaderInfo().Hash.MarshalHex()) + resp, err = blcKeeper.Contains(sdkCtx, query) + if err != nil { + t.Errorf("Valid input let to an error: %s", err) + } + if resp == nil { + t.Errorf("Valid input led to nil response") + } + if resp.Contains { + t.Errorf("Non existent header hash led to true result") + } + + // Test with an existing header + query, _ = types.NewQueryContainsRequest(tree.RandomNode().Hash.MarshalHex()) + resp, err = blcKeeper.Contains(sdkCtx, query) + if err != nil { + t.Errorf("Valid input let to an error: %s", err) + } + if resp == nil { + t.Errorf("Valid input led to nil response") + } + if !resp.Contains { + t.Errorf("Existent header hash led to false result") + } + }) +} + +func FuzzMainChainQuery(f *testing.F) { + /* + Checks: + 1. If the request is nil, an error is returned + 2. If the pagination key is not a valid hash, an error is returned. + 3. If the pagination key does not correspond to an existing header, an error is returned. + 4. If the pagination key is not on the main chain, an error is returned. + 5. If the pagination key has not been set, + the first `limit` items of the main chain are returned + 6. If the pagination key has been set, the `limit` items after it are returned. + 7. End of pagination: the last elements are returned properly and the next_key is set to nil. 
+ + Data Generation: + - Generate a random tree of headers with different PoW and insert them into the headers storage. + - Calculate the main chain using the `HeadersState().MainChain()` function (here we only test the query) + */ + datagen.AddRandomSeedsToFuzzer(f, 100) + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + blcKeeper, ctx := testkeeper.BTCLightClientKeeper(t) + sdkCtx := sdk.WrapSDKContext(ctx) + + // Test nil input + resp, err := blcKeeper.MainChain(sdkCtx, nil) + if resp != nil { + t.Errorf("Nil input led to a non-nil response") + } + if err == nil { + t.Errorf("Nil input led to a nil error") + } + + // Test pagination key being invalid + // We want the key to have a positive length + bzSz := datagen.RandomIntOtherThan(bbl.BTCHeaderHashLen-1, bbl.BTCHeaderHashLen*10) + 1 + key := datagen.GenRandomByteArray(bzSz) + pagination := constructRequestWithKey(key) + mainchainRequest := types.NewQueryMainChainRequest(pagination) + resp, err = blcKeeper.MainChain(sdkCtx, mainchainRequest) + if resp != nil { + t.Errorf("Invalid key led to a non-nil response") + } + if err == nil { + t.Errorf("Invalid key led to a nil error") + } + + // Generate a random tree of headers and insert it into storage + tree := genRandomTree(blcKeeper, ctx, 1, 10) + + // Check whether the key being set to an element that does not exist leads to an error + pagination = constructRequestWithKey(datagen.GenRandomBTCHeaderInfo().Hash.MustMarshal()) + mainchainRequest = types.NewQueryMainChainRequest(pagination) + resp, err = blcKeeper.MainChain(sdkCtx, mainchainRequest) + if resp != nil { + t.Errorf("Key corresponding to header that does not exist led to a non-nil response") + } + if err == nil { + t.Errorf("Key corresponding to a header that does not exist led to a nil error") + } + + // Get the mainchain + mainchain := tree.GetMainChain() + + // Check whether the key being set to a non-mainchain element leads to an error + // Select a random header + header := 
tree.RandomNode() + // Get the tip + tip := tree.GetTip() + // if the header is not on the mainchain, we can test our assumption + // if it is, randomness will ensure that it does on another test case + if !tree.IsOnNodeChain(tip, header) { + pagination = constructRequestWithKeyAndLimit(header.Hash.MustMarshal(), uint64(len(mainchain))) + mainchainRequest = types.NewQueryMainChainRequest(pagination) + resp, err = blcKeeper.MainChain(sdkCtx, mainchainRequest) + if resp != nil { + t.Errorf("Key corresponding to header that is not on the mainchain led to a non-nil response") + } + if err == nil { + t.Errorf("Key corresponding to a header that is not on the mainchain led to a nil error") + } + } + + // Index into the current element of mainchain that we are iterating + mcIdx := 0 + // Generate a random limit + mcSize := uint64(len(mainchain)) + limit := uint64(rand.Int63n(int64(len(mainchain))) + 1) + + // 50% of the time, do a reverse request + // Generate a page request with a limit and a nil key + pagination = constructRequestWithLimit(limit) + reverse := false + if datagen.OneInN(2) { + reverse = true + pagination.Reverse = true + } + // Generate the initial query + mainchainRequest = types.NewQueryMainChainRequest(pagination) + for headersRetrieved := uint64(0); headersRetrieved < mcSize; headersRetrieved += limit { + resp, err = blcKeeper.MainChain(sdkCtx, mainchainRequest) + if err != nil { + t.Errorf("Valid request led to an error %s", err) + } + if resp == nil { + t.Fatalf("Valid request led to nil response") + } + // If we are on the last page the elements retrieved should be equal to the remaining ones + if headersRetrieved+limit >= mcSize && uint64(len(resp.Headers)) != mcSize-headersRetrieved { + t.Fatalf("On the last page expected %d elements but got %d", mcSize-headersRetrieved, len(resp.Headers)) + } + // Otherwise, the elements retrieved should be equal to the limit + if headersRetrieved+limit < mcSize && uint64(len(resp.Headers)) != limit { + 
t.Fatalf("On an intermediate page expected %d elements but got %d", limit, len(resp.Headers)) + } + + // Iterate through the headers and ensure that they correspond + // to the current index into the mainchain. + for i := 0; i < len(resp.Headers); i++ { + idx := mcIdx + if reverse { + idx = len(mainchain) - mcIdx - 1 + } + if !resp.Headers[i].Eq(mainchain[idx]) { + t.Errorf("%t", reverse) + t.Errorf("Response does not match mainchain. Expected %s got %s", mainchain[idx].Hash, resp.Headers[i].Hash) + } + mcIdx += 1 + } + + // Construct the next page request + pagination = constructRequestWithKeyAndLimit(resp.Pagination.NextKey, limit) + if reverse { + pagination.Reverse = true + } + mainchainRequest = types.NewQueryMainChainRequest(pagination) + } + }) +} + +// Constructors for PageRequest objects +func constructRequestWithKeyAndLimit(key []byte, limit uint64) *query.PageRequest { + // If limit is 0, set one randomly + if limit == 0 { + limit = uint64(rand.Int63() + 1) // Use Int63 instead of Uint64 to avoid overflows + } + return &query.PageRequest{ + Key: key, + Offset: 0, // only offset or key is set + Limit: limit, + CountTotal: false, // only used when offset is used + Reverse: false, + } +} + +func constructRequestWithLimit(limit uint64) *query.PageRequest { + return constructRequestWithKeyAndLimit(nil, limit) +} + +func constructRequestWithKey(key []byte) *query.PageRequest { + return constructRequestWithKeyAndLimit(key, 0) +}