diff --git a/access/grpc/client.go b/access/grpc/client.go index a7aa3d7ba..54f5da838 100644 --- a/access/grpc/client.go +++ b/access/grpc/client.go @@ -268,6 +268,14 @@ func (c *Client) GetLatestProtocolStateSnapshot(ctx context.Context) ([]byte, er return c.grpc.GetLatestProtocolStateSnapshot(ctx) } +func (c *Client) GetProtocolStateSnapshotByBlockID(ctx context.Context, blockID flow.Identifier) ([]byte, error) { + return c.grpc.GetProtocolStateSnapshotByBlockID(ctx, blockID) +} + +func (c *Client) GetProtocolStateSnapshotByHeight(ctx context.Context, blockHeight uint64) ([]byte, error) { + return c.grpc.GetProtocolStateSnapshotByHeight(ctx, blockHeight) +} + func (c *Client) GetExecutionResultForBlockID(ctx context.Context, blockID flow.Identifier) (*flow.ExecutionResult, error) { return c.grpc.GetExecutionResultForBlockID(ctx, blockID) } diff --git a/access/grpc/grpc.go b/access/grpc/grpc.go index ca6aff34f..5540d7ece 100644 --- a/access/grpc/grpc.go +++ b/access/grpc/grpc.go @@ -910,6 +910,32 @@ func (c *BaseClient) GetLatestProtocolStateSnapshot(ctx context.Context, opts .. return res.GetSerializedSnapshot(), nil } +func (c *BaseClient) GetProtocolStateSnapshotByBlockID(ctx context.Context, blockID flow.Identifier, opts ...grpc.CallOption) ([]byte, error) { + req := &access.GetProtocolStateSnapshotByBlockIDRequest{ + BlockId: blockID.Bytes(), + } + + res, err := c.rpcClient.GetProtocolStateSnapshotByBlockID(ctx, req, opts...) + if err != nil { + return nil, newRPCError(err) + } + + return res.GetSerializedSnapshot(), nil +} + +func (c *BaseClient) GetProtocolStateSnapshotByHeight(ctx context.Context, blockHeight uint64, opts ...grpc.CallOption) ([]byte, error) { + req := &access.GetProtocolStateSnapshotByHeightRequest{ + BlockHeight: blockHeight, + } + + res, err := c.rpcClient.GetProtocolStateSnapshotByHeight(ctx, req, opts...) + if err != nil { + return nil, newRPCError(err) + } + + return res.GetSerializedSnapshot(), nil +} + func (c *BaseClient) GetExecutionResultForBlockID(ctx context.Context, blockID flow.Identifier, opts ...grpc.CallOption) (*flow.ExecutionResult, error) { er, err := c.rpcClient.GetExecutionResultForBlockID(ctx, &access.GetExecutionResultForBlockIDRequest{ BlockId: convert.IdentifierToMessage(blockID), diff --git a/access/grpc/grpc_test.go b/access/grpc/grpc_test.go index 3ea6fb980..a88169b50 100644 --- a/access/grpc/grpc_test.go +++ b/access/grpc/grpc_test.go @@ -1500,6 +1500,64 @@ func TestClient_GetLatestProtocolStateSnapshot(t *testing.T) { })) } +func TestClient_GetProtocolStateSnapshotByBlockID(t *testing.T) { + ids := test.IdentifierGenerator() + + t.Run("Success", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + blockID := ids.New() + + expected := &access.ProtocolStateSnapshotResponse{ + SerializedSnapshot: make([]byte, 128), + } + _, err := rand.Read(expected.SerializedSnapshot) + assert.NoError(t, err) + + rpc.On("GetProtocolStateSnapshotByBlockID", ctx, mock.Anything).Return(expected, nil) + + res, err := c.GetProtocolStateSnapshotByBlockID(ctx, blockID) + assert.NoError(t, err) + assert.Equal(t, expected.SerializedSnapshot, res) + })) + + t.Run("Internal error", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + blockID := ids.New() + + rpc.On("GetProtocolStateSnapshotByBlockID", ctx, mock.Anything). + Return(nil, errInternal) + + _, err := c.GetProtocolStateSnapshotByBlockID(ctx, blockID) + assert.Error(t, err) + assert.Equal(t, codes.Internal, status.Code(err)) + })) +} + +func TestClient_GetProtocolStateSnapshotByHeight(t *testing.T) { + blockHeight := uint64(42) + + t.Run("Success", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + expected := &access.ProtocolStateSnapshotResponse{ + SerializedSnapshot: make([]byte, 128), + } + _, err := rand.Read(expected.SerializedSnapshot) + assert.NoError(t, err) + + rpc.On("GetProtocolStateSnapshotByHeight", ctx, mock.Anything).Return(expected, nil) + + res, err := c.GetProtocolStateSnapshotByHeight(ctx, blockHeight) + assert.NoError(t, err) + assert.Equal(t, expected.SerializedSnapshot, res) + })) + + t.Run("Internal error", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + rpc.On("GetProtocolStateSnapshotByHeight", ctx, mock.Anything). + Return(nil, errInternal) + + _, err := c.GetProtocolStateSnapshotByHeight(ctx, blockHeight) + assert.Error(t, err) + assert.Equal(t, codes.Internal, status.Code(err)) + })) +} + func TestClient_GetExecutionResultForBlockID(t *testing.T) { ids := test.IdentifierGenerator() t.Run("Success", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) {