diff --git a/access/api.go b/access/api.go index adeb7284c10..acd4d6138b8 100644 --- a/access/api.go +++ b/access/api.go @@ -203,10 +203,50 @@ type API interface { // // If invalid parameters will be supplied SubscribeBlockDigestsFromLatest will return a failed subscription. SubscribeBlockDigestsFromLatest(ctx context.Context, blockStatus flow.BlockStatus) subscription.Subscription - // SubscribeTransactionStatuses streams transaction statuses starting from the reference block saved in the - // transaction itself until the block containing the transaction becomes sealed or expired. When the transaction - // status becomes TransactionStatusSealed or TransactionStatusExpired, the subscription will automatically shut down. - SubscribeTransactionStatuses(ctx context.Context, tx *flow.TransactionBody, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription + // SubscribeTransactionStatusesFromStartBlockID subscribes to transaction status updates for a given transaction ID. + // Monitoring begins from the specified block ID. The subscription streams status updates until the transaction + // reaches a final state (TransactionStatusSealed or TransactionStatusExpired). When the transaction reaches one of + // these final statuses, the subscription will automatically terminate. + // + // Parameters: + // - ctx: The context to manage the subscription's lifecycle, including cancellation. + // - txID: The identifier of the transaction to monitor. + // - startBlockID: The block ID from which to start monitoring. + // - requiredEventEncodingVersion: The version of event encoding required for the subscription. + SubscribeTransactionStatusesFromStartBlockID(ctx context.Context, txID flow.Identifier, startBlockID flow.Identifier, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription + // SubscribeTransactionStatusesFromStartHeight subscribes to transaction status updates for a given transaction ID. + // Monitoring begins from the specified block height. The subscription streams status updates until the transaction + // reaches a final state (TransactionStatusSealed or TransactionStatusExpired). When the transaction reaches one of + // these final statuses, the subscription will automatically terminate. + // + // Parameters: + // - ctx: The context to manage the subscription's lifecycle, including cancellation. + // - txID: The unique identifier of the transaction to monitor. + // - startHeight: The block height from which to start monitoring. + // - requiredEventEncodingVersion: The version of event encoding required for the subscription. + SubscribeTransactionStatusesFromStartHeight(ctx context.Context, txID flow.Identifier, startHeight uint64, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription + // SubscribeTransactionStatusesFromLatest subscribes to transaction status updates for a given transaction ID. + // Monitoring begins from the latest block. The subscription streams status updates until the transaction + // reaches a final state (TransactionStatusSealed or TransactionStatusExpired). When the transaction reaches one of + // these final statuses, the subscription will automatically terminate. + // + // Parameters: + // - ctx: The context to manage the subscription's lifecycle, including cancellation. + // - txID: The unique identifier of the transaction to monitor. + // - requiredEventEncodingVersion: The version of event encoding required for the subscription. + SubscribeTransactionStatusesFromLatest(ctx context.Context, txID flow.Identifier, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription + // SendAndSubscribeTransactionStatuses sends a transaction to the execution node and subscribes to its status updates. + // Monitoring begins from the reference block saved in the transaction itself and streams status updates until the transaction + // reaches a final state (TransactionStatusSealed or TransactionStatusExpired). Once a final status is reached, the subscription + // automatically terminates. + // + // Parameters: + // - ctx: The context to manage the transaction sending and subscription lifecycle, including cancellation. + // - tx: The transaction body to be sent and monitored. + // - requiredEventEncodingVersion: The version of event encoding required for the subscription. + // + // If the transaction cannot be sent, the subscription will fail and return a failed subscription. + SendAndSubscribeTransactionStatuses(ctx context.Context, tx *flow.TransactionBody, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription } // TODO: Combine this with flow.TransactionResult? diff --git a/access/handler.go b/access/handler.go index b974e7034fc..bcf401a2884 100644 --- a/access/handler.go +++ b/access/handler.go @@ -1425,12 +1425,7 @@ func (h *Handler) SendAndSubscribeTransactionStatuses( return status.Error(codes.InvalidArgument, err.Error()) } - err = h.api.SendTransaction(ctx, &tx) - if err != nil { - return err - } - - sub := h.api.SubscribeTransactionStatuses(ctx, &tx, request.GetEventEncodingVersion()) + sub := h.api.SendAndSubscribeTransactionStatuses(ctx, &tx, request.GetEventEncodingVersion()) messageIndex := counters.NewMonotonousCounter(0) return subscription.HandleRPCSubscription(sub, func(txResults []*TransactionResult) error { diff --git a/access/mock/api.go b/access/mock/api.go index eaaf6c428f2..13c35b293d3 100644 --- a/access/mock/api.go +++ b/access/mock/api.go @@ -1145,6 +1145,26 @@ func (_m *API) Ping(ctx context.Context) error { return r0 } +// SendAndSubscribeTransactionStatuses provides a mock function with given fields: ctx, tx, requiredEventEncodingVersion +func (_m *API) SendAndSubscribeTransactionStatuses(ctx context.Context, tx *flow.TransactionBody, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription { + ret := _m.Called(ctx, tx, requiredEventEncodingVersion) + + if len(ret) == 0 { + panic("no return value specified for SendAndSubscribeTransactionStatuses") + } + + var r0 subscription.Subscription + if rf, ok := ret.Get(0).(func(context.Context, *flow.TransactionBody, entities.EventEncodingVersion) subscription.Subscription); ok { + r0 = rf(ctx, tx, requiredEventEncodingVersion) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(subscription.Subscription) + } + } + + return r0 +} + // SendTransaction provides a mock function with given fields: ctx, tx func (_m *API) SendTransaction(ctx context.Context, tx *flow.TransactionBody) error { ret := _m.Called(ctx, tx) @@ -1343,17 +1363,57 @@ func (_m *API) SubscribeBlocksFromStartHeight(ctx context.Context, startHeight u return r0 } -// SubscribeTransactionStatuses provides a mock function with given fields: ctx, tx, requiredEventEncodingVersion -func (_m *API) SubscribeTransactionStatuses(ctx context.Context, tx *flow.TransactionBody, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription { - ret := _m.Called(ctx, tx, requiredEventEncodingVersion) +// SubscribeTransactionStatusesFromLatest provides a mock function with given fields: ctx, txID, requiredEventEncodingVersion +func (_m *API) SubscribeTransactionStatusesFromLatest(ctx context.Context, txID flow.Identifier, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription { + ret := _m.Called(ctx, txID, requiredEventEncodingVersion) if len(ret) == 0 { - panic("no return value specified for SubscribeTransactionStatuses") + panic("no return value specified for SubscribeTransactionStatusesFromLatest") } var r0 subscription.Subscription - if rf, ok := ret.Get(0).(func(context.Context, *flow.TransactionBody, entities.EventEncodingVersion) subscription.Subscription); ok { - r0 = rf(ctx, tx, requiredEventEncodingVersion) + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, entities.EventEncodingVersion) subscription.Subscription); ok { + r0 = rf(ctx, txID, requiredEventEncodingVersion) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(subscription.Subscription) + } + } + + return r0 +} + +// SubscribeTransactionStatusesFromStartBlockID provides a mock function with given fields: ctx, txID, startBlockID, requiredEventEncodingVersion +func (_m *API) SubscribeTransactionStatusesFromStartBlockID(ctx context.Context, txID flow.Identifier, startBlockID flow.Identifier, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription { + ret := _m.Called(ctx, txID, startBlockID, requiredEventEncodingVersion) + + if len(ret) == 0 { + panic("no return value specified for SubscribeTransactionStatusesFromStartBlockID") + } + + var r0 subscription.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, flow.Identifier, entities.EventEncodingVersion) subscription.Subscription); ok { + r0 = rf(ctx, txID, startBlockID, requiredEventEncodingVersion) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(subscription.Subscription) + } + } + + return r0 +} + +// SubscribeTransactionStatusesFromStartHeight provides a mock function with given fields: ctx, txID, startHeight, requiredEventEncodingVersion +func (_m *API) SubscribeTransactionStatusesFromStartHeight(ctx context.Context, txID flow.Identifier, startHeight uint64, requiredEventEncodingVersion entities.EventEncodingVersion) subscription.Subscription { + ret := _m.Called(ctx, txID, startHeight, requiredEventEncodingVersion) + + if len(ret) == 0 { + panic("no return value specified for SubscribeTransactionStatusesFromStartHeight") + } + + var r0 subscription.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64, entities.EventEncodingVersion) subscription.Subscription); ok { + r0 = rf(ctx, txID, startHeight, requiredEventEncodingVersion) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(subscription.Subscription) diff --git a/cmd/util/cmd/run-script/cmd.go b/cmd/util/cmd/run-script/cmd.go index 171f97e76b7..59646d0687a 100644 --- a/cmd/util/cmd/run-script/cmd.go +++ b/cmd/util/cmd/run-script/cmd.go @@ -532,7 +532,32 @@ func (*api) SubscribeBlockDigestsFromLatest( return nil } -func (*api) SubscribeTransactionStatuses( +func (a *api) SubscribeTransactionStatusesFromStartBlockID( + _ context.Context, + _ flow.Identifier, + _ flow.Identifier, + _ entities.EventEncodingVersion, +) subscription.Subscription { + return nil +} + +func (a *api) SubscribeTransactionStatusesFromStartHeight( + _ context.Context, + _ flow.Identifier, + _ uint64, + _ entities.EventEncodingVersion, +) subscription.Subscription { + return nil +} + +func (a *api) SubscribeTransactionStatusesFromLatest( + _ context.Context, + _ flow.Identifier, + _ entities.EventEncodingVersion, +) subscription.Subscription { + return nil +} +func (a *api) SendAndSubscribeTransactionStatuses( _ context.Context, _ *flow.TransactionBody, _ entities.EventEncodingVersion, diff --git a/engine/access/rest/http/request/event_type.go b/engine/access/rest/common/parser/event_type.go similarity index 98% rename from engine/access/rest/http/request/event_type.go rename to engine/access/rest/common/parser/event_type.go index c3f425d81c8..f1ba7ca1acb 100644 --- a/engine/access/rest/http/request/event_type.go +++ b/engine/access/rest/common/parser/event_type.go @@ -1,4 +1,4 @@ -package request +package parser import ( "fmt" diff --git a/engine/access/rest/http/request/get_events.go b/engine/access/rest/http/request/get_events.go index c864cf24a47..dee55f98ded 100644 --- a/engine/access/rest/http/request/get_events.go +++ b/engine/access/rest/http/request/get_events.go @@ -71,7 +71,7 @@ func (g *GetEvents) Parse(rawType string, rawStart string, rawEnd string, rawBlo if rawType == "" { return fmt.Errorf("event type must be provided") } - var eventType EventType + var eventType parser.EventType err = eventType.Parse(rawType) if err != nil { return err diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 4f0e2260ae5..98643f19638 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -51,7 +51,14 @@ func NewServer(serverAPI access.API, builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } - dataProviderFactory := dp.NewDataProviderFactory(logger, stateStreamApi, serverAPI) + dataProviderFactory := dp.NewDataProviderFactory( + logger, + stateStreamApi, + serverAPI, + chain, + stateStreamConfig.EventFilterConfig, + stateStreamConfig.HeartbeatInterval, + ) builder.AddWebsocketsRoute(chain, wsConfig, config.MaxRequestSize, dataProviderFactory) c := cors.New(cors.Options{ diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider.go b/engine/access/rest/websockets/data_providers/account_statuses_provider.go new file mode 100644 index 00000000000..396dcbc7b9a --- /dev/null +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider.go @@ -0,0 +1,176 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" +) + +// accountStatusesArguments contains the arguments required for subscribing to account statuses +type accountStatusesArguments struct { + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from + Filter state_stream.AccountStatusFilter // Filter applied to events for a given subscription +} + +type AccountStatusesDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + stateStreamApi state_stream.API + + heartbeatInterval uint64 +} + +var _ DataProvider = (*AccountStatusesDataProvider)(nil) + +// NewAccountStatusesDataProvider creates a new instance of AccountStatusesDataProvider. +func NewAccountStatusesDataProvider( + ctx context.Context, + logger zerolog.Logger, + stateStreamApi state_stream.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, +) (*AccountStatusesDataProvider, error) { + p := &AccountStatusesDataProvider{ + logger: logger.With().Str("component", "account-statuses-data-provider").Logger(), + stateStreamApi: stateStreamApi, + heartbeatInterval: heartbeatInterval, + } + + // Initialize arguments passed to the provider. + accountStatusesArgs, err := parseAccountStatusesArguments(arguments, chain, eventFilterConfig) + if err != nil { + return nil, fmt.Errorf("invalid arguments for account statuses data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, accountStatusesArgs), // Set up a subscription to account statuses based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *AccountStatusesDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *AccountStatusesDataProvider) createSubscription(ctx context.Context, args accountStatusesArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.stateStreamApi.SubscribeAccountStatusesFromStartBlockID(ctx, args.StartBlockID, args.Filter) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.stateStreamApi.SubscribeAccountStatusesFromStartHeight(ctx, args.StartBlockHeight, args.Filter) + } + + return p.stateStreamApi.SubscribeAccountStatusesFromLatestBlock(ctx, args.Filter) +} + +// handleResponse processes an account statuses and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *AccountStatusesDataProvider) handleResponse() func(accountStatusesResponse *backend.AccountStatusesResponse) error { + blocksSinceLastMessage := uint64(0) + messageIndex := counters.NewMonotonousCounter(0) + + return func(accountStatusesResponse *backend.AccountStatusesResponse) error { + // check if there are any events in the response. if not, do not send a message unless the last + // response was more than HeartbeatInterval blocks ago + if len(accountStatusesResponse.AccountEvents) == 0 { + blocksSinceLastMessage++ + if blocksSinceLastMessage < p.heartbeatInterval { + return nil + } + blocksSinceLastMessage = 0 + } + + index := messageIndex.Value() + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return status.Errorf(codes.Internal, "message index already incremented to %d", messageIndex.Value()) + } + + p.send <- &models.AccountStatusesResponse{ + BlockID: accountStatusesResponse.BlockID.String(), + Height: strconv.FormatUint(accountStatusesResponse.Height, 10), + AccountEvents: accountStatusesResponse.AccountEvents, + MessageIndex: index, + } + + return nil + } +} + +// parseAccountStatusesArguments validates and initializes the account statuses arguments. +func parseAccountStatusesArguments( + arguments models.Arguments, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, +) (accountStatusesArguments, error) { + var args accountStatusesArguments + + // Parse block arguments + startBlockID, startBlockHeight, err := ParseStartBlock(arguments) + if err != nil { + return args, err + } + args.StartBlockID = startBlockID + args.StartBlockHeight = startBlockHeight + + // Parse 'event_types' as a JSON array + var eventTypes parser.EventTypes + if eventTypesIn, ok := arguments["event_types"]; ok && eventTypesIn != "" { + result, ok := eventTypesIn.([]string) + if !ok { + return args, fmt.Errorf("'event_types' must be an array of string") + } + + err := eventTypes.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'event_types': %w", err) + } + } + + // Parse 'accountAddresses' as []string{} + var accountAddresses []string + if accountAddressesIn, ok := arguments["account_addresses"]; ok && accountAddressesIn != "" { + accountAddresses, ok = accountAddressesIn.([]string) + if !ok { + return args, fmt.Errorf("'account_addresses' must be an array of string") + } + } + + // Initialize the event filter with the parsed arguments + args.Filter, err = state_stream.NewAccountStatusFilter(eventFilterConfig, chain, eventTypes.Flow(), accountAddresses) + if err != nil { + return args, fmt.Errorf("failed to create event filter: %w", err) + } + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go new file mode 100644 index 00000000000..8f689ca034a --- /dev/null +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go @@ -0,0 +1,269 @@ +package data_providers + +import ( + "context" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// AccountStatusesProviderSuite is a test suite for testing the account statuses providers functionality. +type AccountStatusesProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *ssmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestNewAccountStatusesDataProvider(t *testing.T) { + suite.Run(t, new(AccountStatusesProviderSuite)) +} + +func (s *AccountStatusesProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = ssmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + s.api, + nil, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestAccountStatusesDataProvider_HappyPath tests the behavior of the account statuses data provider +// when it is configured correctly and operating under normal conditions. It +// validates that events are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_HappyPath() { + + expectedEvents := []flow.Event{ + unittest.EventFixture(state_stream.CoreEventAccountCreated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(state_stream.CoreEventAccountKeyAdded, 0, 0, unittest.IdentifierFixture(), 0), + } + + var expectedAccountStatusesResponses []backend.AccountStatusesResponse + for i := 0; i < len(expectedEvents); i++ { + expectedAccountStatusesResponses = append(expectedAccountStatusesResponses, backend.AccountStatusesResponse{ + Height: s.rootBlock.Header.Height, + BlockID: s.rootBlock.ID(), + AccountEvents: map[string]flow.EventsList{ + unittest.RandomAddressFixture().String(): expectedEvents, + }, + }) + } + + testHappyPath( + s.T(), + AccountStatusesTopic, + s.factory, + s.subscribeAccountStatusesDataProviderTestCases(), + func(dataChan chan interface{}) { + for i := 0; i < len(expectedAccountStatusesResponses); i++ { + dataChan <- &expectedAccountStatusesResponses[i] + } + }, + expectedAccountStatusesResponses, + s.requireAccountStatuses, + ) +} + +func (s *AccountStatusesProviderSuite) subscribeAccountStatusesDataProviderTestCases() []testType { + return []testType{ + { + name: "SubscribeAccountStatusesFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "event_types": []string{"flow.AccountCreated", "flow.AccountKeyAdded"}, + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeAccountStatusesFromStartHeight happy path", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeAccountStatusesFromLatestBlock happy path", + arguments: models.Arguments{}, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromLatestBlock", + mock.Anything, + mock.Anything, + ).Return(sub).Once() + }, + }, + } +} + +// requireAccountStatuses ensures that the received account statuses information matches the expected data. +func (s *AccountStatusesProviderSuite) requireAccountStatuses( + v interface{}, + expectedResponse interface{}, +) { + expectedAccountStatusesResponse, ok := expectedResponse.(backend.AccountStatusesResponse) + require.True(s.T(), ok, "unexpected type: %T", expectedResponse) + + actualResponse, ok := v.(*models.AccountStatusesResponse) + require.True(s.T(), ok, "Expected *models.AccountStatusesResponse, got %T", v) + + require.Equal(s.T(), expectedAccountStatusesResponse.BlockID.String(), actualResponse.BlockID) + require.Equal(s.T(), len(expectedAccountStatusesResponse.AccountEvents), len(actualResponse.AccountEvents)) + + for key, expectedEvents := range expectedAccountStatusesResponse.AccountEvents { + actualEvents, ok := actualResponse.AccountEvents[key] + require.True(s.T(), ok, "Missing key in actual AccountEvents: %s", key) + + s.Require().Equal(expectedEvents, actualEvents, "Mismatch for key: %s", key) + } +} + +// TestAccountStatusesDataProvider_InvalidArguments tests the behavior of the account statuses data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Providing both 'start_block_id' and 'start_block_height' simultaneously. +// 2. Invalid 'start_block_id' argument. +// 3. Invalid 'start_block_height' argument. +func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := AccountStatusesTopic + + for _, test := range invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewAccountStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// TestMessageIndexAccountStatusesProviderResponse_HappyPath tests that MessageIndex values in response are strictly increasing. +func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderResponse_HappyPath() { + ctx := context.Background() + send := make(chan interface{}, 10) + topic := AccountStatusesTopic + accountStatusesCount := 4 + + // Create a channel to simulate the subscription's account statuses channel + accountStatusesChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(accountStatusesChan)) + sub.On("Err").Return(nil) + + s.api.On("SubscribeAccountStatusesFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) + + arguments := + map[string]interface{}{ + "start_block_id": s.rootBlock.ID().String(), + } + + // Create the AccountStatusesDataProvider instance + provider, err := NewAccountStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine to simulate subscription processing + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the account statuses channel + go func() { + defer close(accountStatusesChan) // Close the channel when done + + for i := 0; i < accountStatusesCount; i++ { + accountStatusesChan <- &backend.AccountStatusesResponse{} + } + }() + + // Collect responses + var responses []*models.AccountStatusesResponse + for i := 0; i < accountStatusesCount; i++ { + res := <-send + accountStatusesRes, ok := res.(*models.AccountStatusesResponse) + s.Require().True(ok, "Expected *models.AccountStatusesResponse, got %T", res) + responses = append(responses, accountStatusesRes) + } + + // Verifying that indices are starting from 0 + s.Require().Equal(uint64(0), responses[0].MessageIndex, "Expected MessageIndex to start with 0") + + // Verifying that indices are strictly increasing + for i := 1; i < len(responses); i++ { + prevIndex := responses[i-1].MessageIndex + currentIndex := responses[i].MessageIndex + s.Require().Equal(prevIndex+1, currentIndex, "Expected MessageIndex to increment by 1") + } + + // Ensure the provider is properly closed after the test + provider.Close() +} diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider.go b/engine/access/rest/websockets/data_providers/block_digests_provider.go index 1fa3f7a6dc7..80307be6b64 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider.go @@ -69,7 +69,7 @@ func (p *BlockDigestsDataProvider) Run() error { } // createSubscription creates a new subscription using the specified input arguments. -func (p *BlockDigestsDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { +func (p *BlockDigestsDataProvider) createSubscription(ctx context.Context, args blocksArguments) subscription.Subscription { if args.StartBlockID != flow.ZeroID { return p.api.SubscribeBlockDigestsFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) } diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go index 476edf77111..975716c74af 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go @@ -106,20 +106,26 @@ func (s *BlockDigestsProviderSuite) validBlockDigestsArgumentsTestCases() []test // validates that block digests are correctly streamed to the channel and ensures // no unexpected errors occur. func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_HappyPath() { - s.testHappyPath( + testHappyPath( + s.T(), BlockDigestsTopic, + s.factory, s.validBlockDigestsArgumentsTestCases(), - func(dataChan chan interface{}, blocks []*flow.Block) { - for _, block := range blocks { + func(dataChan chan interface{}) { + for _, block := range s.blocks { dataChan <- flow.NewBlockDigest(block.Header.ID(), block.Header.Height, block.Header.Timestamp) } }, - s.requireBlockDigests, + s.blocks, + s.requireBlockDigest, ) } // requireBlockHeaders ensures that the received block header information matches the expected data. -func (s *BlocksProviderSuite) requireBlockDigests(v interface{}, expectedBlock *flow.Block) { +func (s *BlocksProviderSuite) requireBlockDigest(v interface{}, expected interface{}) { + expectedBlock, ok := expected.(*flow.Block) + require.True(s.T(), ok, "unexpected type: %T", v) + actualResponse, ok := v.(*models.BlockDigestMessageResponse) require.True(s.T(), ok, "unexpected response type: %T", v) diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider.go b/engine/access/rest/websockets/data_providers/block_headers_provider.go index 4f9e29e2428..4fddeb499f2 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider.go @@ -69,7 +69,7 @@ func (p *BlockHeadersDataProvider) Run() error { } // createSubscription creates a new subscription using the specified input arguments. -func (p *BlockHeadersDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { +func (p *BlockHeadersDataProvider) createSubscription(ctx context.Context, args blocksArguments) subscription.Subscription { if args.StartBlockID != flow.ZeroID { return p.api.SubscribeBlockHeadersFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) } diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go index 57c262d8795..b929a46d076 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go @@ -106,20 +106,26 @@ func (s *BlockHeadersProviderSuite) validBlockHeadersArgumentsTestCases() []test // validates that block headers are correctly streamed to the channel and ensures // no unexpected errors occur. func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_HappyPath() { - s.testHappyPath( + testHappyPath( + s.T(), BlockHeadersTopic, + s.factory, s.validBlockHeadersArgumentsTestCases(), - func(dataChan chan interface{}, blocks []*flow.Block) { - for _, block := range blocks { + func(dataChan chan interface{}) { + for _, block := range s.blocks { dataChan <- block.Header } }, - s.requireBlockHeaders, + s.blocks, + s.requireBlockHeader, ) } // requireBlockHeaders ensures that the received block header information matches the expected data. -func (s *BlockHeadersProviderSuite) requireBlockHeaders(v interface{}, expectedBlock *flow.Block) { +func (s *BlockHeadersProviderSuite) requireBlockHeader(v interface{}, expected interface{}) { + expectedBlock, ok := expected.(*flow.Block) + require.True(s.T(), ok, "unexpected type: %T", v) + actualResponse, ok := v.(*models.BlockHeaderMessageResponse) require.True(s.T(), ok, "unexpected response type: %T", v) diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go index 72cfaa6f554..6c09c4a623a 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -16,7 +16,7 @@ import ( ) // BlocksArguments contains the arguments required for subscribing to blocks / block headers / block digests -type BlocksArguments struct { +type blocksArguments struct { StartBlockID flow.Identifier // ID of the block to start subscription from StartBlockHeight uint64 // Height of the block to start subscription from BlockStatus flow.BlockStatus // Status of blocks to subscribe to @@ -78,7 +78,7 @@ func (p *BlocksDataProvider) Run() error { } // createSubscription creates a new subscription using the specified input arguments. -func (p *BlocksDataProvider) createSubscription(ctx context.Context, args BlocksArguments) subscription.Subscription { +func (p *BlocksDataProvider) createSubscription(ctx context.Context, args blocksArguments) subscription.Subscription { if args.StartBlockID != flow.ZeroID { return p.api.SubscribeBlocksFromStartBlockID(ctx, args.StartBlockID, args.BlockStatus) } @@ -91,12 +91,16 @@ func (p *BlocksDataProvider) createSubscription(ctx context.Context, args Blocks } // ParseBlocksArguments validates and initializes the blocks arguments. -func ParseBlocksArguments(arguments models.Arguments) (BlocksArguments, error) { - var args BlocksArguments +func ParseBlocksArguments(arguments models.Arguments) (blocksArguments, error) { + var args blocksArguments // Parse 'block_status' if blockStatusIn, ok := arguments["block_status"]; ok { - blockStatus, err := parser.ParseBlockStatus(blockStatusIn) + result, ok := blockStatusIn.(string) + if !ok { + return args, fmt.Errorf("'block_status' must be string") + } + blockStatus, err := parser.ParseBlockStatus(result) if err != nil { return args, err } @@ -105,34 +109,52 @@ func ParseBlocksArguments(arguments models.Arguments) (BlocksArguments, error) { return args, fmt.Errorf("'block_status' must be provided") } + // Parse block arguments + startBlockID, startBlockHeight, err := ParseStartBlock(arguments) + if err != nil { + return args, err + } + args.StartBlockID = startBlockID + args.StartBlockHeight = startBlockHeight + + return args, nil +} + +func ParseStartBlock(arguments models.Arguments) (flow.Identifier, uint64, error) { startBlockIDIn, hasStartBlockID := arguments["start_block_id"] startBlockHeightIn, hasStartBlockHeight := arguments["start_block_height"] - // Ensure only one of start_block_id or start_block_height is provided + // Check for mutual exclusivity of start_block_id and start_block_height early if hasStartBlockID && hasStartBlockHeight { - return args, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") + return flow.ZeroID, 0, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") } - // Parse 'start_block_id' if provided + // Parse 'start_block_id' if hasStartBlockID { + result, ok := startBlockIDIn.(string) + if !ok { + return flow.ZeroID, request.EmptyHeight, fmt.Errorf("'start_block_id' must be a string") + } var startBlockID parser.ID - err := startBlockID.Parse(startBlockIDIn) + err := startBlockID.Parse(result) if err != nil { - return args, err + return flow.ZeroID, request.EmptyHeight, fmt.Errorf("invalid 'start_block_id': %w", err) } - args.StartBlockID = startBlockID.Flow() + return startBlockID.Flow(), request.EmptyHeight, nil } - // Parse 'start_block_height' if provided + // Parse 'start_block_height' if hasStartBlockHeight { - var err error - args.StartBlockHeight, err = util.ToUint64(startBlockHeightIn) + result, ok := startBlockHeightIn.(string) + if !ok { + return flow.ZeroID, 0, fmt.Errorf("'start_block_height' must be a string") + } + startBlockHeight, err := util.ToUint64(result) if err != nil { - return args, fmt.Errorf("invalid 'start_block_height': %w", err) + return flow.ZeroID, request.EmptyHeight, fmt.Errorf("invalid 'start_block_height': %w", err) } - } else { - args.StartBlockHeight = request.EmptyHeight + return flow.ZeroID, startBlockHeight, nil } - return args, nil + return flow.ZeroID, request.EmptyHeight, nil } diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go index 9e07f9459e9..85136ae5819 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider_test.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -5,7 +5,6 @@ import ( "fmt" "strconv" "testing" - "time" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" @@ -15,26 +14,15 @@ import ( accessmock "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) const unknownBlockStatus = "unknown_block_status" -type testErrType struct { - name string - arguments models.Arguments - expectedErrorMsg string -} - -// testType represents a valid test scenario for subscribing -type testType struct { - name string - arguments models.Arguments - setupBackend func(sub *statestreamsmock.Subscription) -} - // BlocksProviderSuite is a test suite for testing the block providers functionality. type BlocksProviderSuite struct { suite.Suite @@ -73,7 +61,13 @@ func (s *BlocksProviderSuite) SetupTest() { } s.finalizedBlock = parent - s.factory = NewDataProviderFactory(s.log, nil, s.api) + s.factory = NewDataProviderFactory( + s.log, + nil, + s.api, + flow.Testnet.Chain(), + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) s.Require().NotNil(s.factory) } @@ -189,85 +183,28 @@ func (s *BlocksProviderSuite) validBlockArgumentsTestCases() []testType { // validates that blocks are correctly streamed to the channel and ensures // no unexpected errors occur. func (s *BlocksProviderSuite) TestBlocksDataProvider_HappyPath() { - s.testHappyPath( + testHappyPath( + s.T(), BlocksTopic, + s.factory, s.validBlockArgumentsTestCases(), - func(dataChan chan interface{}, blocks []*flow.Block) { - for _, block := range blocks { + func(dataChan chan interface{}) { + for _, block := range s.blocks { dataChan <- block } }, + s.blocks, s.requireBlock, ) } // requireBlocks ensures that the received block information matches the expected data. -func (s *BlocksProviderSuite) requireBlock(v interface{}, expectedBlock *flow.Block) { +func (s *BlocksProviderSuite) requireBlock(v interface{}, expected interface{}) { + expectedBlock, ok := expected.(*flow.Block) + require.True(s.T(), ok, "unexpected type: %T", v) + actualResponse, ok := v.(*models.BlockMessageResponse) require.True(s.T(), ok, "unexpected response type: %T", v) s.Require().Equal(expectedBlock, actualResponse.Block) } - -// testHappyPath tests a variety of scenarios for data providers in -// happy path scenarios. This function runs parameterized test cases that -// simulate various configurations and verifies that the data provider operates -// as expected without encountering errors. -// -// Arguments: -// - topic: The topic associated with the data provider. -// - tests: A slice of test cases to run, each specifying setup and validation logic. -// - sendData: A function to simulate emitting data into the subscription's data channel. -// - requireFn: A function to validate the output received in the send channel. -func (s *BlocksProviderSuite) testHappyPath( - topic string, - tests []testType, - sendData func(chan interface{}, []*flow.Block), - requireFn func(interface{}, *flow.Block), -) { - for _, test := range tests { - s.Run(test.name, func() { - ctx := context.Background() - send := make(chan interface{}, 10) - - // Create a channel to simulate the subscription's data channel - dataChan := make(chan interface{}) - - // Create a mock subscription and mock the channel - sub := statestreamsmock.NewSubscription(s.T()) - sub.On("Channel").Return((<-chan interface{})(dataChan)) - sub.On("Err").Return(nil) - test.setupBackend(sub) - - // Create the data provider instance - provider, err := s.factory.NewDataProvider(ctx, topic, test.arguments, send) - s.Require().NotNil(provider) - s.Require().NoError(err) - - // Run the provider in a separate goroutine - go func() { - err = provider.Run() - s.Require().NoError(err) - }() - - // Simulate emitting data to the data channel - go func() { - defer close(dataChan) - sendData(dataChan, s.blocks) - }() - - // Collect responses - for _, b := range s.blocks { - unittest.RequireReturnsBefore(s.T(), func() { - v, ok := <-send - s.Require().True(ok, "channel closed while waiting for block %x %v: err: %v", b.Header.Height, b.ID(), sub.Err()) - - requireFn(v, b) - }, time.Second, fmt.Sprintf("timed out waiting for block %d %v", b.Header.Height, b.ID())) - } - - // Ensure the provider is properly closed after the test - provider.Close() - }) - } -} diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go new file mode 100644 index 00000000000..318e8081d2c --- /dev/null +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -0,0 +1,185 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" +) + +// eventsArguments contains the arguments required for subscribing to events +type eventsArguments struct { + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from + Filter state_stream.EventFilter // Filter applied to events for a given subscription +} + +// EventsDataProvider is responsible for providing events +type EventsDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + stateStreamApi state_stream.API + + heartbeatInterval uint64 +} + +var _ DataProvider = (*EventsDataProvider)(nil) + +// NewEventsDataProvider creates a new instance of EventsDataProvider. +func NewEventsDataProvider( + ctx context.Context, + logger zerolog.Logger, + stateStreamApi state_stream.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, +) (*EventsDataProvider, error) { + p := &EventsDataProvider{ + logger: logger.With().Str("component", "events-data-provider").Logger(), + stateStreamApi: stateStreamApi, + heartbeatInterval: heartbeatInterval, + } + + // Initialize arguments passed to the provider. + eventArgs, err := parseEventsArguments(arguments, chain, eventFilterConfig) + if err != nil { + return nil, fmt.Errorf("invalid arguments for events data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, eventArgs), // Set up a subscription to events based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *EventsDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// handleResponse processes events and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *EventsDataProvider) handleResponse() func(eventsResponse *backend.EventsResponse) error { + blocksSinceLastMessage := uint64(0) + messageIndex := counters.NewMonotonousCounter(0) + + return func(eventsResponse *backend.EventsResponse) error { + // check if there are any events in the response. if not, do not send a message unless the last + // response was more than HeartbeatInterval blocks ago + if len(eventsResponse.Events) == 0 { + blocksSinceLastMessage++ + if blocksSinceLastMessage < p.heartbeatInterval { + return nil + } + blocksSinceLastMessage = 0 + } + + index := messageIndex.Value() + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return fmt.Errorf("message index already incremented to: %d", messageIndex.Value()) + } + + p.send <- &models.EventResponse{ + BlockId: eventsResponse.BlockID.String(), + BlockHeight: strconv.FormatUint(eventsResponse.Height, 10), + BlockTimestamp: eventsResponse.BlockTimestamp, + Events: eventsResponse.Events, + MessageIndex: index, + } + + return nil + } +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *EventsDataProvider) createSubscription(ctx context.Context, args eventsArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.stateStreamApi.SubscribeEventsFromStartBlockID(ctx, args.StartBlockID, args.Filter) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.stateStreamApi.SubscribeEventsFromStartHeight(ctx, args.StartBlockHeight, args.Filter) + } + + return p.stateStreamApi.SubscribeEventsFromLatest(ctx, args.Filter) +} + +// parseEventsArguments validates and initializes the events arguments. +func parseEventsArguments( + arguments models.Arguments, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, +) (eventsArguments, error) { + var args eventsArguments + + // Parse block arguments + startBlockID, startBlockHeight, err := ParseStartBlock(arguments) + if err != nil { + return args, err + } + args.StartBlockID = startBlockID + args.StartBlockHeight = startBlockHeight + + // Parse 'event_types' as a JSON array + var eventTypes parser.EventTypes + if eventTypesIn, ok := arguments["event_types"]; ok && eventTypesIn != "" { + result, ok := eventTypesIn.([]string) + if !ok { + return args, fmt.Errorf("'event_types' must be an array of string") + } + + err := eventTypes.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'event_types': %w", err) + } + } + + // Parse 'addresses' as []string{} + var addresses []string + if addressesIn, ok := arguments["addresses"]; ok && addressesIn != "" { + addresses, ok = addressesIn.([]string) + if !ok { + return args, fmt.Errorf("'addresses' must be an array of string") + } + } + + // Parse 'contracts' as []string{} + var contracts []string + if contractsIn, ok := arguments["contracts"]; ok && contractsIn != "" { + contracts, ok = contractsIn.([]string) + if !ok { + return args, fmt.Errorf("'contracts' must be an array of string") + } + } + + // Initialize the event filter with the parsed arguments + args.Filter, err = state_stream.NewEventFilter(eventFilterConfig, chain, eventTypes.Flow(), addresses, contracts) + if err != nil { + return args, fmt.Errorf("failed to create event filter: %w", err) + } + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/events_provider_test.go b/engine/access/rest/websockets/data_providers/events_provider_test.go new file mode 100644 index 00000000000..4902f3b35a6 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/events_provider_test.go @@ -0,0 +1,295 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// EventsProviderSuite is a test suite for testing the events providers functionality. +type EventsProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *ssmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestEventsProviderSuite(t *testing.T) { + suite.Run(t, new(EventsProviderSuite)) +} + +func (s *EventsProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = ssmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + s.api, + nil, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestEventsDataProvider_HappyPath tests the behavior of the events data provider +// when it is configured correctly and operating under normal conditions. It +// validates that events are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *EventsProviderSuite) TestEventsDataProvider_HappyPath() { + + expectedEvents := []flow.Event{ + unittest.EventFixture(flow.EventAccountCreated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(flow.EventAccountUpdated, 0, 0, unittest.IdentifierFixture(), 0), + } + + var expectedEventsResponses []backend.EventsResponse + for i := 0; i < len(expectedEvents); i++ { + expectedEventsResponses = append(expectedEventsResponses, backend.EventsResponse{ + Height: s.rootBlock.Header.Height, + BlockID: s.rootBlock.ID(), + Events: expectedEvents, + BlockTimestamp: s.rootBlock.Header.Timestamp, + }) + } + + testHappyPath( + s.T(), + EventsTopic, + s.factory, + s.subscribeEventsDataProviderTestCases(), + func(dataChan chan interface{}) { + for i := 0; i < len(expectedEventsResponses); i++ { + dataChan <- &expectedEventsResponses[i] + } + }, + expectedEventsResponses, + s.requireEvents, + ) +} + +// subscribeEventsDataProviderTestCases generates test cases for events data providers. +func (s *EventsProviderSuite) subscribeEventsDataProviderTestCases() []testType { + return []testType{ + { + name: "SubscribeBlocksFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "event_types": []string{"flow.AccountCreated", "flow.AccountUpdated"}, + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeEventsFromStartHeight happy path", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeEventsFromLatest happy path", + arguments: models.Arguments{}, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromLatest", + mock.Anything, + mock.Anything, + ).Return(sub).Once() + }, + }, + } +} + +// requireEvents ensures that the received event information matches the expected data. +func (s *EventsProviderSuite) requireEvents(v interface{}, expectedResponse interface{}) { + expectedEventsResponse, ok := expectedResponse.(backend.EventsResponse) + require.True(s.T(), ok, "unexpected type: %T", expectedResponse) + + actualResponse, ok := v.(*models.EventResponse) + require.True(s.T(), ok, "Expected *models.EventResponse, got %T", v) + + s.Require().ElementsMatch(expectedEventsResponse.Events, actualResponse.Events) +} + +// invalidArgumentsTestCases returns a list of test cases with invalid argument combinations +// for testing the behavior of events data providers. Each test case includes a name, +// a set of input arguments, and the expected error message that should be returned. +// +// The test cases cover scenarios such as: +// 1. Supplying both 'start_block_id' and 'start_block_height' simultaneously, which is not allowed. +// 2. Providing invalid 'start_block_id' value. +// 3. Providing invalid 'start_block_height' value. +func invalidArgumentsTestCases() []testErrType { + return []testErrType{ + { + name: "provide both 'start_block_id' and 'start_block_height' arguments", + arguments: models.Arguments{ + "start_block_id": unittest.BlockFixture().ID().String(), + "start_block_height": fmt.Sprintf("%d", unittest.BlockFixture().Header.Height), + }, + expectedErrorMsg: "can only provide either 'start_block_id' or 'start_block_height'", + }, + { + name: "invalid 'start_block_id' argument", + arguments: map[string]interface{}{ + "start_block_id": "invalid_block_id", + }, + expectedErrorMsg: "invalid ID format", + }, + { + name: "invalid 'start_block_height' argument", + arguments: map[string]interface{}{ + "start_block_height": "-1", + }, + expectedErrorMsg: "value must be an unsigned 64 bit integer", + }, + } +} + +// TestEventsDataProvider_InvalidArguments tests the behavior of the event data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Providing both 'start_block_id' and 'start_block_height' simultaneously. +// 2. Invalid 'start_block_id' argument. +// 3. Invalid 'start_block_height' argument. +func (s *EventsProviderSuite) TestEventsDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := EventsTopic + + for _, test := range invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewEventsDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// TestMessageIndexEventProviderResponse_HappyPath tests that MessageIndex values in response are strictly increasing. +func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() { + ctx := context.Background() + send := make(chan interface{}, 10) + topic := EventsTopic + eventsCount := 4 + + // Create a channel to simulate the subscription's event channel + eventChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(eventChan)) + sub.On("Err").Return(nil) + + s.api.On("SubscribeEventsFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) + + arguments := + map[string]interface{}{ + "start_block_id": s.rootBlock.ID().String(), + } + + // Create the EventsDataProvider instance + provider, err := NewEventsDataProvider( + ctx, + s.log, + s.api, + topic, + arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine to simulate subscription processing + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting events to the event channel + go func() { + defer close(eventChan) // Close the channel when done + + for i := 0; i < eventsCount; i++ { + eventChan <- &backend.EventsResponse{ + Height: s.rootBlock.Header.Height, + } + } + }() + + // Collect responses + var responses []*models.EventResponse + for i := 0; i < eventsCount; i++ { + res := <-send + eventRes, ok := res.(*models.EventResponse) + s.Require().True(ok, "Expected *models.EventResponse, got %T", res) + responses = append(responses, eventRes) + } + + // Verifying that indices are starting from 1 + s.Require().Equal(uint64(0), responses[0].MessageIndex, "Expected MessageIndex to start with 0") + + // Verifying that indices are strictly increasing + for i := 1; i < len(responses); i++ { + prevIndex := responses[i-1].MessageIndex + currentIndex := responses[i].MessageIndex + s.Require().Equal(prevIndex+1, currentIndex, "Expected MessageIndex to increment by 1") + } + + // Ensure the provider is properly closed after the test + provider.Close() +} diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index 72f4a6b7633..ff23708d337 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -9,17 +9,19 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/websockets/models" "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/model/flow" ) // Constants defining various topic names used to specify different types of // data providers. const ( - EventsTopic = "events" - AccountStatusesTopic = "account_statuses" - BlocksTopic = "blocks" - BlockHeadersTopic = "block_headers" - BlockDigestsTopic = "block_digests" - TransactionStatusesTopic = "transaction_statuses" + EventsTopic = "events" + AccountStatusesTopic = "account_statuses" + BlocksTopic = "blocks" + BlockHeadersTopic = "block_headers" + BlockDigestsTopic = "block_digests" + TransactionStatusesTopic = "transaction_statuses" + SendTransactionStatusesTopic = "send_transaction_statuses" ) // DataProviderFactory defines an interface for creating data providers @@ -43,6 +45,10 @@ type DataProviderFactoryImpl struct { stateStreamApi state_stream.API accessApi access.API + + chain flow.Chain + eventFilterConfig state_stream.EventFilterConfig + heartbeatInterval uint64 } // NewDataProviderFactory creates a new DataProviderFactory @@ -56,11 +62,17 @@ func NewDataProviderFactory( logger zerolog.Logger, stateStreamApi state_stream.API, accessApi access.API, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, ) *DataProviderFactoryImpl { return &DataProviderFactoryImpl{ - logger: logger, - stateStreamApi: stateStreamApi, - accessApi: accessApi, + logger: logger, + stateStreamApi: stateStreamApi, + accessApi: accessApi, + chain: chain, + eventFilterConfig: eventFilterConfig, + heartbeatInterval: heartbeatInterval, } } @@ -87,11 +99,14 @@ func (s *DataProviderFactoryImpl) NewDataProvider( return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) case BlockDigestsTopic: return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) - // TODO: Implemented handlers for each topic should be added in respective case - case EventsTopic, - AccountStatusesTopic, - TransactionStatusesTopic: - return nil, fmt.Errorf(`topic "%s" not implemented yet`, topic) + case EventsTopic: + return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + case AccountStatusesTopic: + return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + case TransactionStatusesTopic: + return NewTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + case SendTransactionStatusesTopic: + return NewSendTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) default: return nil, fmt.Errorf("unsupported topic \"%s\"", topic) } diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go index 2ed2b075d0c..f18455b7edd 100644 --- a/engine/access/rest/websockets/data_providers/factory_test.go +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -11,7 +11,9 @@ import ( accessmock "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" statestreammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -43,7 +45,16 @@ func (s *DataProviderFactorySuite) SetupTest() { s.ctx = context.Background() s.ch = make(chan interface{}) - s.factory = NewDataProviderFactory(log, s.stateStreamApi, s.accessApi) + chain := flow.Testnet.Chain() + + s.factory = NewDataProviderFactory( + log, + s.stateStreamApi, + s.accessApi, + chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) s.Require().NotNil(s.factory) } @@ -99,6 +110,50 @@ func (s *DataProviderFactorySuite) TestSupportedTopics() { s.accessApi.AssertExpectations(s.T()) }, }, + { + name: "events topic", + topic: EventsTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.stateStreamApi.On("SubscribeEventsFromLatest", mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, + { + name: "account statuses topic", + topic: AccountStatusesTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.stateStreamApi.On("SubscribeAccountStatusesFromLatestBlock", mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, + { + name: "transaction statuses topic", + topic: TransactionStatusesTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.accessApi.On("SubscribeTransactionStatusesFromLatest", mock.Anything, mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, + { + name: "send transaction statuses topic", + topic: SendTransactionStatusesTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.accessApi.On("SendAndSubscribeTransactionStatuses", mock.Anything, mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, } for _, test := range testCases { diff --git a/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider.go new file mode 100644 index 00000000000..c2ad0ca5937 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider.go @@ -0,0 +1,235 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/util" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" + + "github.com/onflow/flow/protobuf/go/flow/entities" +) + +// sendTransactionStatusesArguments contains the arguments required for sending tx and subscribing to transaction statuses +type sendTransactionStatusesArguments struct { + Transaction flow.TransactionBody // The transaction body to be sent and monitored. +} + +type SendTransactionStatusesDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + api access.API +} + +var _ DataProvider = (*SendTransactionStatusesDataProvider)(nil) + +func NewSendTransactionStatusesDataProvider( + ctx context.Context, + logger zerolog.Logger, + api access.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, +) (*SendTransactionStatusesDataProvider, error) { + p := &SendTransactionStatusesDataProvider{ + logger: logger.With().Str("component", "send-transaction-statuses-data-provider").Logger(), + api: api, + } + + // Initialize arguments passed to the provider. + sendTxStatusesArgs, err := parseSendTransactionStatusesArguments(arguments) + if err != nil { + return nil, fmt.Errorf("invalid arguments for send tx statuses data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, sendTxStatusesArgs), // Set up a subscription to tx statuses based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *SendTransactionStatusesDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *SendTransactionStatusesDataProvider) createSubscription( + ctx context.Context, + args sendTransactionStatusesArguments, +) subscription.Subscription { + return p.api.SendAndSubscribeTransactionStatuses(ctx, &args.Transaction, entities.EventEncodingVersion_JSON_CDC_V0) +} + +// handleResponse processes an account statuses and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *SendTransactionStatusesDataProvider) handleResponse() func(txResults []*access.TransactionResult) error { + + messageIndex := counters.NewMonotonousCounter(0) + + return func(txResults []*access.TransactionResult) error { + + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return status.Errorf(codes.Internal, "message index already incremented to %d", messageIndex.Value()) + } + index := messageIndex.Value() + + p.send <- &models.TransactionStatusesResponse{ + TransactionResults: txResults, + MessageIndex: index, + } + + return nil + } + +} + +// parseAccountStatusesArguments validates and initializes the account statuses arguments. +func parseSendTransactionStatusesArguments( + arguments models.Arguments, +) (sendTransactionStatusesArguments, error) { + var args sendTransactionStatusesArguments + var tx flow.TransactionBody + + if scriptIn, ok := arguments["script"]; ok && scriptIn != "" { + result, ok := scriptIn.(string) + if !ok { + return args, fmt.Errorf("'script' must be a string") + } + + script, err := util.FromBase64(result) + if err != nil { + return args, fmt.Errorf("invalid 'script': %w", err) + } + + tx.Script = script + } + + if argumentsIn, ok := arguments["arguments"]; ok && argumentsIn != "" { + result, ok := argumentsIn.([]string) + if !ok { + return args, fmt.Errorf("'arguments' must be a []string type") + } + + var argumentsData [][]byte + for _, arg := range result { + argument, err := util.FromBase64(arg) + if err != nil { + return args, fmt.Errorf("invalid 'arguments': %w", err) + } + + argumentsData = append(argumentsData, argument) + } + + tx.Arguments = argumentsData + } + + if referenceBlockIDIn, ok := arguments["reference_block_id"]; ok && referenceBlockIDIn != "" { + result, ok := referenceBlockIDIn.(string) + if !ok { + return args, fmt.Errorf("'reference_block_id' must be a string") + } + + var referenceBlockID parser.ID + err := referenceBlockID.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'reference_block_id': %w", err) + } + + tx.ReferenceBlockID = referenceBlockID.Flow() + } + + if gasLimitIn, ok := arguments["gas_limit"]; ok && gasLimitIn != "" { + result, ok := gasLimitIn.(string) + if !ok { + return args, fmt.Errorf("'gas_limit' must be a string") + } + + gasLimit, err := util.ToUint64(result) + if err != nil { + return args, fmt.Errorf("invalid 'gas_limit': %w", err) + } + tx.GasLimit = gasLimit + } + + if payerIn, ok := arguments["payer"]; ok && payerIn != "" { + result, ok := payerIn.(string) + if !ok { + return args, fmt.Errorf("'payerIn' must be a string") + } + + payerAddr, err := flow.StringToAddress(result) + if err != nil { + return args, fmt.Errorf("invalid 'payer': %w", err) + } + tx.Payer = payerAddr + } + + if proposalKeyIn, ok := arguments["proposal_key"]; ok && proposalKeyIn != "" { + proposalKey, ok := proposalKeyIn.(flow.ProposalKey) + if !ok { + return args, fmt.Errorf("'proposal_key' must be a object (ProposalKey)") + } + + tx.ProposalKey = proposalKey + } + + if authorizersIn, ok := arguments["authorizers"]; ok && authorizersIn != "" { + result, ok := authorizersIn.([]string) + if !ok { + return args, fmt.Errorf("'authorizers' must be a []string type") + } + + var authorizersData []flow.Address + for _, auth := range result { + authorizer, err := flow.StringToAddress(auth) + if err != nil { + return args, fmt.Errorf("invalid 'authorizers': %w", err) + } + + authorizersData = append(authorizersData, authorizer) + } + + tx.Authorizers = authorizersData + } + + if payloadSignaturesIn, ok := arguments["payload_signatures"]; ok && payloadSignaturesIn != "" { + payloadSignatures, ok := payloadSignaturesIn.([]flow.TransactionSignature) + if !ok { + return args, fmt.Errorf("'payload_signatures' must be an array of objects (TransactionSignature)") + } + + tx.PayloadSignatures = payloadSignatures + } + + if envelopeSignaturesIn, ok := arguments["envelope_signatures"]; ok && envelopeSignaturesIn != "" { + envelopeSignatures, ok := envelopeSignaturesIn.([]flow.TransactionSignature) + if !ok { + return args, fmt.Errorf("'envelope_signatures' must be an array of objects (TransactionSignature)") + } + + tx.EnvelopeSignatures = envelopeSignatures + } + args.Transaction = tx + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider_test.go new file mode 100644 index 00000000000..ea617265d8f --- /dev/null +++ b/engine/access/rest/websockets/data_providers/send_transaction_statuses_provider_test.go @@ -0,0 +1,233 @@ +package data_providers + +import ( + "context" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + accessmock "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" + + "github.com/onflow/flow/protobuf/go/flow/entities" +) + +type SendTransactionStatusesProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *accessmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestNewSendTransactionStatusesDataProvider(t *testing.T) { + suite.Run(t, new(SendTransactionStatusesProviderSuite)) +} + +func (s *SendTransactionStatusesProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = accessmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + nil, + s.api, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestSendTransactionStatusesDataProvider_HappyPath tests the behavior of the send transaction statuses data provider +// when it is configured correctly and operating under normal conditions. It +// validates that tx statuses are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *TransactionStatusesProviderSuite) TestSendTransactionStatusesDataProvider_HappyPath() { + + sendTxStatutesTestCases := []testType{ + { + name: "SubscribeTransactionStatusesFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SendAndSubscribeTransactionStatuses", + mock.Anything, + mock.Anything, + entities.EventEncodingVersion_JSON_CDC_V0, + ).Return(sub).Once() + }, + }, + } + + expectedResponse := expectedTransactionStatusesResponse(s.rootBlock) + + testHappyPath( + s.T(), + SendTransactionStatusesTopic, + s.factory, + sendTxStatutesTestCases, + func(dataChan chan interface{}) { + for i := 0; i < len(expectedResponse); i++ { + dataChan <- expectedResponse[i] + } + }, + expectedResponse, + s.requireTransactionStatuses, + ) + +} + +// TestSendTransactionStatusesDataProvider_InvalidArguments tests the behavior of the send transaction statuses data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Invalid 'script' type. +// 2. Invalid 'script' value. +// 3. Invalid 'arguments' type. +// 4. Invalid 'arguments' value. +// 5. Invalid 'reference_block_id' value. +// 6. Invalid 'gas_limit' value. +// 7. Invalid 'payer' value. +// 8. Invalid 'proposal_key' value. +// 9. Invalid 'authorizers' value. +// 10. Invalid 'payload_signatures' value. +// 11. Invalid 'envelope_signatures' value. +func (s *SendTransactionStatusesProviderSuite) TestSendTransactionStatusesDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := SendTransactionStatusesTopic + + for _, test := range invalidSendTransactionStatusesArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewSendTransactionStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// invalidSendTransactionStatusesArgumentsTestCases returns a list of test cases with invalid argument combinations +// for testing the behavior of send transaction statuses data providers. Each test case includes a name, +// a set of input arguments, and the expected error message that should be returned. +// +// The test cases cover scenarios such as: +// 1. Providing invalid 'script' type. +// 2. Providing invalid 'script' value. +// 3. Providing invalid 'arguments' type. +// 4. Providing invalid 'arguments' value. +// 5. Providing invalid 'reference_block_id' value. +// 6. Providing invalid 'gas_limit' value. +// 7. Providing invalid 'payer' value. +// 8. Providing invalid 'proposal_key' value. +// 9. Providing invalid 'authorizers' value. +// 10. Providing invalid 'payload_signatures' value. +// 11. Providing invalid 'envelope_signatures' value. +func invalidSendTransactionStatusesArgumentsTestCases() []testErrType { + return []testErrType{ + { + name: "invalid 'script' argument type", + arguments: map[string]interface{}{ + "script": 0, + }, + expectedErrorMsg: "'script' must be a string", + }, + { + name: "invalid 'script' argument", + arguments: map[string]interface{}{ + "script": "invalid_script", + }, + expectedErrorMsg: "invalid 'script': illegal base64 data ", + }, + { + name: "invalid 'arguments' type", + arguments: map[string]interface{}{ + "arguments": 0, + }, + expectedErrorMsg: "'arguments' must be a []string type", + }, + { + name: "invalid 'arguments' argument", + arguments: map[string]interface{}{ + "arguments": []string{"invalid_base64_1", "invalid_base64_2"}, + }, + expectedErrorMsg: "invalid 'arguments'", + }, + { + name: "invalid 'reference_block_id' argument", + arguments: map[string]interface{}{ + "reference_block_id": "invalid_reference_block_id", + }, + expectedErrorMsg: "invalid ID format", + }, + { + name: "invalid 'gas_limit' argument", + arguments: map[string]interface{}{ + "gas_limit": "-1", + }, + expectedErrorMsg: "value must be an unsigned 64 bit integer", + }, + { + name: "invalid 'payer' argument", + arguments: map[string]interface{}{ + "payer": "invalid_payer", + }, + expectedErrorMsg: "invalid 'payer': can not decode hex string", + }, + { + name: "invalid 'proposal_key' argument", + arguments: map[string]interface{}{ + "proposal_key": "invalid ProposalKey object", + }, + expectedErrorMsg: "'proposal_key' must be a object (ProposalKey)", + }, + { + name: "invalid 'authorizers' argument", + arguments: map[string]interface{}{ + "authorizers": []string{"invalid_base64_1", "invalid_base64_2"}, + }, + expectedErrorMsg: "invalid 'authorizers': can not decode hex string", + }, + { + name: "invalid 'payload_signatures' argument", + arguments: map[string]interface{}{ + "payload_signatures": "invalid TransactionSignature array", + }, + expectedErrorMsg: "'payload_signatures' must be an array of objects (TransactionSignature)", + }, + { + name: "invalid 'envelope_signatures' argument", + arguments: map[string]interface{}{ + "envelope_signatures": "invalid TransactionSignature array", + }, + expectedErrorMsg: "'envelope_signatures' must be an array of objects (TransactionSignature)", + }, + } +} diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go new file mode 100644 index 00000000000..3e6fa0cc928 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go @@ -0,0 +1,143 @@ +package data_providers + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" + + "github.com/onflow/flow/protobuf/go/flow/entities" +) + +// transactionStatusesArguments contains the arguments required for subscribing to transaction statuses +type transactionStatusesArguments struct { + TxID flow.Identifier // ID of the transaction to monitor. + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from +} + +// TransactionStatusesDataProvider is responsible for providing tx statuses +type TransactionStatusesDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + api access.API +} + +var _ DataProvider = (*TransactionStatusesDataProvider)(nil) + +func NewTransactionStatusesDataProvider( + ctx context.Context, + logger zerolog.Logger, + api access.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, +) (*TransactionStatusesDataProvider, error) { + p := &TransactionStatusesDataProvider{ + logger: logger.With().Str("component", "transaction-statuses-data-provider").Logger(), + api: api, + } + + // Initialize arguments passed to the provider. + txStatusesArgs, err := parseTransactionStatusesArguments(arguments) + if err != nil { + return nil, fmt.Errorf("invalid arguments for tx statuses data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, txStatusesArgs), // Set up a subscription to tx statuses based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *TransactionStatusesDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *TransactionStatusesDataProvider) createSubscription( + ctx context.Context, + args transactionStatusesArguments, +) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.api.SubscribeTransactionStatusesFromStartBlockID(ctx, args.TxID, args.StartBlockID, entities.EventEncodingVersion_JSON_CDC_V0) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.api.SubscribeTransactionStatusesFromStartHeight(ctx, args.TxID, args.StartBlockHeight, entities.EventEncodingVersion_JSON_CDC_V0) + } + + return p.api.SubscribeTransactionStatusesFromLatest(ctx, args.TxID, entities.EventEncodingVersion_JSON_CDC_V0) +} + +// handleResponse processes an account statuses and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *TransactionStatusesDataProvider) handleResponse() func(txResults []*access.TransactionResult) error { + messageIndex := counters.NewMonotonousCounter(0) + + return func(txResults []*access.TransactionResult) error { + + index := messageIndex.Value() + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return status.Errorf(codes.Internal, "message index already incremented to %d", messageIndex.Value()) + } + + p.send <- &models.TransactionStatusesResponse{ + TransactionResults: txResults, + MessageIndex: index, + } + + return nil + } +} + +// parseAccountStatusesArguments validates and initializes the account statuses arguments. +func parseTransactionStatusesArguments( + arguments models.Arguments, +) (transactionStatusesArguments, error) { + var args transactionStatusesArguments + + // Parse block arguments + startBlockID, startBlockHeight, err := ParseStartBlock(arguments) + if err != nil { + return args, err + } + args.StartBlockID = startBlockID + args.StartBlockHeight = startBlockHeight + + if txIDIn, ok := arguments["tx_id"]; ok && txIDIn != "" { + result, ok := txIDIn.(string) + if !ok { + return args, fmt.Errorf("'tx_id' must be a string") + } + var txID parser.ID + err := txID.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'tx_id': %w", err) + } + args.TxID = txID.Flow() + } + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go new file mode 100644 index 00000000000..bfb81f82f81 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go @@ -0,0 +1,319 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/access" + accessmock "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" + + "github.com/onflow/flow/protobuf/go/flow/entities" +) + +type TransactionStatusesProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *accessmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestNewTransactionStatusesDataProvider(t *testing.T) { + suite.Run(t, new(TransactionStatusesProviderSuite)) +} + +func (s *TransactionStatusesProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = accessmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + nil, + s.api, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestTransactionStatusesDataProvider_HappyPath tests the behavior of the transaction statuses data provider +// when it is configured correctly and operating under normal conditions. It +// validates that tx statuses are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *TransactionStatusesProviderSuite) TestTransactionStatusesDataProvider_HappyPath() { + expectedResponse := expectedTransactionStatusesResponse(s.rootBlock) + + testHappyPath( + s.T(), + TransactionStatusesTopic, + s.factory, + s.subscribeTransactionStatusesDataProviderTestCases(), + func(dataChan chan interface{}) { + for i := 0; i < len(expectedResponse); i++ { + dataChan <- expectedResponse[i] + } + }, + expectedResponse, + s.requireTransactionStatuses, + ) +} + +func (s *TransactionStatusesProviderSuite) subscribeTransactionStatusesDataProviderTestCases() []testType { + return []testType{ + { + name: "SubscribeTransactionStatusesFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeTransactionStatusesFromStartBlockID", + mock.Anything, + mock.Anything, + s.rootBlock.ID(), + entities.EventEncodingVersion_JSON_CDC_V0, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeTransactionStatusesFromStartHeight happy path", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeTransactionStatusesFromStartHeight", + mock.Anything, + mock.Anything, + s.rootBlock.Header.Height, + entities.EventEncodingVersion_JSON_CDC_V0, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeTransactionStatusesFromLatest happy path", + arguments: models.Arguments{}, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeTransactionStatusesFromLatest", + mock.Anything, + mock.Anything, + entities.EventEncodingVersion_JSON_CDC_V0, + ).Return(sub).Once() + }, + }, + } +} + +// requireTransactionStatuses ensures that the received transaction statuses information matches the expected data. +func (s *TransactionStatusesProviderSuite) requireTransactionStatuses( + v interface{}, + expectedResponse interface{}, +) { + expectedAccountStatusesResponse, ok := expectedResponse.([]*access.TransactionResult) + require.True(s.T(), ok, "unexpected type: %T", expectedResponse) + + actualResponse, ok := v.(*models.TransactionStatusesResponse) + require.True(s.T(), ok, "Expected *models.TransactionStatusesResponse, got %T", v) + + s.Require().ElementsMatch(expectedAccountStatusesResponse, actualResponse.TransactionResults) +} + +// TestTransactionStatusesDataProvider_InvalidArguments tests the behavior of the transaction statuses data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Invalid 'tx_id' argument. +// 2. Invalid 'start_block_id' argument. +func (s *TransactionStatusesProviderSuite) TestTransactionStatusesDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := TransactionStatusesTopic + + for _, test := range invalidTransactionStatusesArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewTransactionStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// invalidTransactionStatusesArgumentsTestCases returns a list of test cases with invalid argument combinations +// for testing the behavior of transaction statuses data providers. Each test case includes a name, +// a set of input arguments, and the expected error message that should be returned. +// +// The test cases cover scenarios such as: +// 1. Providing both 'start_block_id' and 'start_block_height' simultaneously. +// 2. Providing invalid 'tx_id' value. +// 3. Providing invalid 'start_block_id' value. +// 4. Invalid 'start_block_id' argument. +func invalidTransactionStatusesArgumentsTestCases() []testErrType { + return []testErrType{ + { + name: "provide both 'start_block_id' and 'start_block_height' arguments", + arguments: models.Arguments{ + "start_block_id": unittest.BlockFixture().ID().String(), + "start_block_height": fmt.Sprintf("%d", unittest.BlockFixture().Header.Height), + }, + expectedErrorMsg: "can only provide either 'start_block_id' or 'start_block_height'", + }, + { + name: "invalid 'tx_id' argument", + arguments: map[string]interface{}{ + "tx_id": "invalid_tx_id", + }, + expectedErrorMsg: "invalid ID format", + }, + { + name: "invalid 'start_block_id' argument", + arguments: map[string]interface{}{ + "start_block_id": "invalid_block_id", + }, + expectedErrorMsg: "invalid ID format", + }, + { + name: "invalid 'start_block_height' argument", + arguments: map[string]interface{}{ + "start_block_height": "-1", + }, + expectedErrorMsg: "value must be an unsigned 64 bit integer", + }, + } +} + +// TestMessageIndexTransactionStatusesProviderResponse_HappyPath tests that MessageIndex values in response are strictly increasing. +func (s *TransactionStatusesProviderSuite) TestMessageIndexTransactionStatusesProviderResponse_HappyPath() { + ctx := context.Background() + send := make(chan interface{}, 10) + topic := TransactionStatusesTopic + txStatusesCount := 4 + + // Create a channel to simulate the subscription's account statuses channel + txStatusesChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(txStatusesChan)) + sub.On("Err").Return(nil) + + s.api.On( + "SubscribeTransactionStatusesFromStartBlockID", + mock.Anything, + mock.Anything, + mock.Anything, + entities.EventEncodingVersion_JSON_CDC_V0, + ).Return(sub) + + arguments := + map[string]interface{}{ + "start_block_id": s.rootBlock.ID().String(), + } + + // Create the TransactionStatusesDataProvider instance + provider, err := NewTransactionStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + arguments, + send, + ) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine to simulate subscription processing + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the еч statuses channel + go func() { + defer close(txStatusesChan) // Close the channel when done + + for i := 0; i < txStatusesCount; i++ { + txStatusesChan <- []*access.TransactionResult{} + } + }() + + // Collect responses + var responses []*models.TransactionStatusesResponse + for i := 0; i < txStatusesCount; i++ { + res := <-send + txStatusesRes, ok := res.(*models.TransactionStatusesResponse) + s.Require().True(ok, "Expected *models.TransactionStatusesResponse, got %T", res) + responses = append(responses, txStatusesRes) + } + + // Verifying that indices are starting from 0 + s.Require().Equal(uint64(0), responses[0].MessageIndex, "Expected MessageIndex to start with 0") + + // Verifying that indices are strictly increasing + for i := 1; i < len(responses); i++ { + prevIndex := responses[i-1].MessageIndex + currentIndex := responses[i].MessageIndex + s.Require().Equal(prevIndex+1, currentIndex, "Expected MessageIndex to increment by 1") + } + + // Ensure the provider is properly closed after the test + provider.Close() +} + +func expectedTransactionStatusesResponse(block flow.Block) [][]*access.TransactionResult { + id := unittest.IdentifierFixture() + cid := unittest.IdentifierFixture() + txr := access.TransactionResult{ + Status: flow.TransactionStatusSealed, + StatusCode: 10, + Events: []flow.Event{ + unittest.EventFixture(flow.EventAccountCreated, 1, 0, id, 200), + }, + ErrorMessage: "", + BlockID: block.ID(), + CollectionID: cid, + BlockHeight: block.Header.Height, + } + + var expectedTxStatusesResponses [][]*access.TransactionResult + var expectedTxResultsResponses []*access.TransactionResult + + for i := 0; i < 2; i++ { + expectedTxResultsResponses = append(expectedTxResultsResponses, &txr) + expectedTxStatusesResponses = append(expectedTxStatusesResponses, expectedTxResultsResponses) + } + + return expectedTxStatusesResponses +} diff --git a/engine/access/rest/websockets/data_providers/utittest.go b/engine/access/rest/websockets/data_providers/utittest.go new file mode 100644 index 00000000000..8ade7c127fc --- /dev/null +++ b/engine/access/rest/websockets/data_providers/utittest.go @@ -0,0 +1,96 @@ +package data_providers + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/utils/unittest" +) + +// testType represents a valid test scenario for subscribing +type testType struct { + name string + arguments models.Arguments + setupBackend func(sub *statestreamsmock.Subscription) +} + +// testErrType represents an error cases for subscribing +type testErrType struct { + name string + arguments models.Arguments + expectedErrorMsg string +} + +// testHappyPath tests a variety of scenarios for data providers in +// happy path scenarios. This function runs parameterized test cases that +// simulate various configurations and verifies that the data provider operates +// as expected without encountering errors. +// +// Arguments: +// - topic: The topic associated with the data provider. +// - factory: A factory for creating data provider instance. +// - tests: A slice of test cases to run, each specifying setup and validation logic. +// - sendData: A function to simulate emitting data into the subscription's data channel. +// - expectedResponses: An expected responses to validate the received output. +// - requireFn: A function to validate the output received in the send channel. +func testHappyPath[T any]( + t *testing.T, + topic string, + factory *DataProviderFactoryImpl, + tests []testType, + sendData func(chan interface{}), + expectedResponses []T, + requireFn func(interface{}, interface{}), +) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + send := make(chan interface{}, 10) + + // Create a channel to simulate the subscription's data channel + dataChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := statestreamsmock.NewSubscription(t) + sub.On("Channel").Return((<-chan interface{})(dataChan)) + sub.On("Err").Return(nil) + test.setupBackend(sub) + + // Create the data provider instance + provider, err := factory.NewDataProvider(ctx, topic, test.arguments, send) + require.NotNil(t, provider) + require.NoError(t, err) + + // Run the provider in a separate goroutine + go func() { + err = provider.Run() + require.NoError(t, err) + }() + + // Simulate emitting data to the data channel + go func() { + defer close(dataChan) + sendData(dataChan) + }() + + // Collect responses + for i, expected := range expectedResponses { + unittest.RequireReturnsBefore(t, func() { + v, ok := <-send + require.True(t, ok, "channel closed while waiting for response %v: err: %v", expected, sub.Err()) + + requireFn(v, expected) + }, time.Second, fmt.Sprintf("timed out waiting for response %d %v", i, expected)) + } + + // Ensure the provider is properly closed after the test + provider.Close() + }) + } +} diff --git a/engine/access/rest/websockets/legacy/request/subscribe_events.go b/engine/access/rest/websockets/legacy/request/subscribe_events.go index 1110d3582d4..9e53e7c5fca 100644 --- a/engine/access/rest/websockets/legacy/request/subscribe_events.go +++ b/engine/access/rest/websockets/legacy/request/subscribe_events.go @@ -81,7 +81,7 @@ func (g *SubscribeEvents) Parse( g.StartHeight = 0 } - var eventTypes request.EventTypes + var eventTypes parser.EventTypes err = eventTypes.Parse(rawTypes) if err != nil { return err diff --git a/engine/access/rest/websockets/models/account_models.go b/engine/access/rest/websockets/models/account_models.go new file mode 100644 index 00000000000..fdb6826b4f1 --- /dev/null +++ b/engine/access/rest/websockets/models/account_models.go @@ -0,0 +1,11 @@ +package models + +import "github.com/onflow/flow-go/model/flow" + +// AccountStatusesResponse is the response message for 'events' topic. +type AccountStatusesResponse struct { + BlockID string `json:"blockID"` + Height string `json:"height"` + AccountEvents map[string]flow.EventsList `json:"account_events"` + MessageIndex uint64 `json:"message_index"` +} diff --git a/engine/access/rest/websockets/models/event_models.go b/engine/access/rest/websockets/models/event_models.go new file mode 100644 index 00000000000..0659cbc6937 --- /dev/null +++ b/engine/access/rest/websockets/models/event_models.go @@ -0,0 +1,16 @@ +package models + +import ( + "time" + + "github.com/onflow/flow-go/model/flow" +) + +// EventResponse is the response message for 'events' topic. +type EventResponse struct { + BlockId string `json:"block_id"` + BlockHeight string `json:"block_height"` + BlockTimestamp time.Time `json:"block_timestamp"` + Events []flow.Event `json:"events"` + MessageIndex uint64 `json:"message_index"` +} diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go index 95ad17e3708..03b37aee5f1 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe.go @@ -1,6 +1,6 @@ package models -type Arguments map[string]string +type Arguments map[string]interface{} // SubscribeMessageRequest represents a request to subscribe to a topic. type SubscribeMessageRequest struct { diff --git a/engine/access/rest/websockets/models/tx_statuses_model.go b/engine/access/rest/websockets/models/tx_statuses_model.go new file mode 100644 index 00000000000..32754a06603 --- /dev/null +++ b/engine/access/rest/websockets/models/tx_statuses_model.go @@ -0,0 +1,11 @@ +package models + +import ( + "github.com/onflow/flow-go/access" +) + +// TransactionStatusesResponse is the response message for 'events' topic. +type TransactionStatusesResponse struct { + TransactionResults []*access.TransactionResult `json:"transaction_results"` + MessageIndex uint64 `json:"message_index"` +} diff --git a/engine/access/rpc/backend/backend.go b/engine/access/rpc/backend/backend.go index d4666af9529..5ed2232b87b 100644 --- a/engine/access/rpc/backend/backend.go +++ b/engine/access/rpc/backend/backend.go @@ -260,6 +260,7 @@ func New(params Params) (*Backend, error) { executionResults: params.ExecutionResults, subscriptionHandler: params.SubscriptionHandler, blockTracker: params.BlockTracker, + sendTransaction: b.SendTransaction, } retry.SetBackend(b) diff --git a/engine/access/rpc/backend/backend_stream_transactions.go b/engine/access/rpc/backend/backend_stream_transactions.go index a82b365240e..ae6678e0a8b 100644 --- a/engine/access/rpc/backend/backend_stream_transactions.go +++ b/engine/access/rpc/backend/backend_stream_transactions.go @@ -21,6 +21,9 @@ import ( "github.com/onflow/flow/protobuf/go/flow/entities" ) +// sendTransaction defines a function type for sending a transaction. +type sendTransaction func(ctx context.Context, tx *flow.TransactionBody) error + // backendSubscribeTransactions handles transaction subscriptions. type backendSubscribeTransactions struct { txLocalDataProvider *TransactionsLocalDataProvider @@ -30,38 +33,121 @@ type backendSubscribeTransactions struct { subscriptionHandler *subscription.SubscriptionHandler blockTracker subscription.BlockTracker + sendTransaction sendTransaction } -// TransactionSubscriptionMetadata holds data representing the status state for each transaction subscription. -type TransactionSubscriptionMetadata struct { +// transactionSubscriptionMetadata holds data representing the status state for each transaction subscription. +type transactionSubscriptionMetadata struct { *access.TransactionResult txReferenceBlockID flow.Identifier blockWithTx *flow.Header txExecuted bool eventEncodingVersion entities.EventEncodingVersion + shouldTriggerPending bool } -// SubscribeTransactionStatuses subscribes to transaction status changes starting from the transaction reference block ID. -// If invalid tx parameters will be supplied SubscribeTransactionStatuses will return a failed subscription. -func (b *backendSubscribeTransactions) SubscribeTransactionStatuses( +// SendAndSubscribeTransactionStatuses sends a transaction and subscribes to its status updates. +// It starts monitoring the status from the transaction's reference block ID. +// If the transaction cannot be sent or an error occurs during subscription creation, a failed subscription is returned. +func (b *backendSubscribeTransactions) SendAndSubscribeTransactionStatuses( ctx context.Context, tx *flow.TransactionBody, requiredEventEncodingVersion entities.EventEncodingVersion, ) subscription.Subscription { - nextHeight, err := b.blockTracker.GetStartHeightFromBlockID(tx.ReferenceBlockID) + if err := b.sendTransaction(ctx, tx); err != nil { + b.log.Error().Err(err).Str("tx_id", tx.ID().String()).Msg("failed to send transaction") + return subscription.NewFailedSubscription(err, "failed to send transaction") + } + + return b.createSubscription(ctx, tx.ID(), tx.ReferenceBlockID, 0, tx.ReferenceBlockID, requiredEventEncodingVersion, true) +} + +// SubscribeTransactionStatusesFromStartHeight subscribes to the status updates of a transaction. +// Monitoring starts from the specified block height. +// If the block height cannot be determined or an error occurs during subscription creation, a failed subscription is returned. +func (b *backendSubscribeTransactions) SubscribeTransactionStatusesFromStartHeight( + ctx context.Context, + txID flow.Identifier, + startHeight uint64, + requiredEventEncodingVersion entities.EventEncodingVersion, +) subscription.Subscription { + return b.createSubscription(ctx, txID, flow.ZeroID, startHeight, flow.ZeroID, requiredEventEncodingVersion, false) +} + +// SubscribeTransactionStatusesFromStartBlockID subscribes to the status updates of a transaction. +// Monitoring starts from the specified block ID. +// If the block ID cannot be determined or an error occurs during subscription creation, a failed subscription is returned. +func (b *backendSubscribeTransactions) SubscribeTransactionStatusesFromStartBlockID( + ctx context.Context, + txID flow.Identifier, + startBlockID flow.Identifier, + requiredEventEncodingVersion entities.EventEncodingVersion, +) subscription.Subscription { + return b.createSubscription(ctx, txID, startBlockID, 0, flow.ZeroID, requiredEventEncodingVersion, false) +} + +// SubscribeTransactionStatusesFromLatest subscribes to the status updates of a transaction. +// Monitoring starts from the latest block. +// If the block cannot be retrieved or an error occurs during subscription creation, a failed subscription is returned. +func (b *backendSubscribeTransactions) SubscribeTransactionStatusesFromLatest( + ctx context.Context, + txID flow.Identifier, + requiredEventEncodingVersion entities.EventEncodingVersion, +) subscription.Subscription { + header, err := b.txLocalDataProvider.state.Sealed().Head() if err != nil { - return subscription.NewFailedSubscription(err, "could not get start height") + b.log.Error().Err(err).Msg("failed to retrieve latest block") + return subscription.NewFailedSubscription(err, "failed to retrieve latest block") } - txInfo := TransactionSubscriptionMetadata{ + return b.createSubscription(ctx, txID, header.ID(), 0, flow.ZeroID, requiredEventEncodingVersion, false) +} + +// createSubscription initializes a subscription for monitoring a transaction's status. +// If the start height cannot be determined, a failed subscription is returned. +func (b *backendSubscribeTransactions) createSubscription( + ctx context.Context, + txID flow.Identifier, + startBlockID flow.Identifier, + startBlockHeight uint64, + referenceBlockID flow.Identifier, + requiredEventEncodingVersion entities.EventEncodingVersion, + shouldTriggerPending bool, +) subscription.Subscription { + var nextHeight uint64 + var err error + + // Get height to start subscription from + if startBlockID == flow.ZeroID { + if nextHeight, err = b.blockTracker.GetStartHeightFromHeight(startBlockHeight); err != nil { + b.log.Error().Err(err).Uint64("block_height", startBlockHeight).Msg("failed to get start height") + return subscription.NewFailedSubscription(err, "failed to get start height") + } + } else { + if nextHeight, err = b.blockTracker.GetStartHeightFromBlockID(startBlockID); err != nil { + b.log.Error().Err(err).Str("block_id", startBlockID.String()).Msg("failed to get start height") + return subscription.NewFailedSubscription(err, "failed to get start height") + } + } + + // choose initial transaction status + initialStatus := flow.TransactionStatusUnknown + if shouldTriggerPending { + // The status of the first pending transaction should be returned immediately, as the transaction has already been sent. + // This should occur only once for each subscription. + initialStatus = flow.TransactionStatusPending + } + + txInfo := transactionSubscriptionMetadata{ TransactionResult: &access.TransactionResult{ - TransactionID: tx.ID(), + TransactionID: txID, BlockID: flow.ZeroID, - Status: flow.TransactionStatusUnknown, + Status: initialStatus, }, - txReferenceBlockID: tx.ReferenceBlockID, + txReferenceBlockID: referenceBlockID, blockWithTx: nil, eventEncodingVersion: requiredEventEncodingVersion, + shouldTriggerPending: shouldTriggerPending, } return b.subscriptionHandler.Subscribe(ctx, nextHeight, b.getTransactionStatusResponse(&txInfo)) @@ -69,16 +155,19 @@ func (b *backendSubscribeTransactions) SubscribeTransactionStatuses( // getTransactionStatusResponse returns a callback function that produces transaction status // subscription responses based on new blocks. -func (b *backendSubscribeTransactions) getTransactionStatusResponse(txInfo *TransactionSubscriptionMetadata) func(context.Context, uint64) (interface{}, error) { +func (b *backendSubscribeTransactions) getTransactionStatusResponse(txInfo *transactionSubscriptionMetadata) func(context.Context, uint64) (interface{}, error) { return func(ctx context.Context, height uint64) (interface{}, error) { err := b.checkBlockReady(height) if err != nil { return nil, err } - // If the transaction status already reported the final status, return with no data available - if txInfo.Status == flow.TransactionStatusSealed || txInfo.Status == flow.TransactionStatusExpired { - return nil, fmt.Errorf("transaction final status %s was already reported: %w", txInfo.Status.String(), subscription.ErrEndOfData) + if txInfo.shouldTriggerPending { + return b.handlePendingStatus(txInfo) + } + + if b.isTransactionFinalStatus(txInfo) { + return nil, fmt.Errorf("transaction final status %s already reported: %w", txInfo.Status.String(), subscription.ErrEndOfData) } // If on this step transaction block not available, search for it. @@ -120,19 +209,8 @@ func (b *backendSubscribeTransactions) getTransactionStatusResponse(txInfo *Tran } // If block with transaction was not found, get transaction status to check if it different from last status - if txInfo.blockWithTx == nil { - txInfo.Status, err = b.txLocalDataProvider.DeriveUnknownTransactionStatus(txInfo.txReferenceBlockID) - } else if txInfo.Status == prevTxStatus { - // When a block with the transaction is available, it is possible to receive a new transaction status while - // searching for the transaction result. Otherwise, it remains unchanged. So, if the old and new transaction - // statuses are the same, the current transaction status should be retrieved. - txInfo.Status, err = b.txLocalDataProvider.DeriveTransactionStatus(txInfo.blockWithTx.Height, txInfo.txExecuted) - } - if err != nil { - if !errors.Is(err, state.ErrUnknownSnapshotReference) { - irrecoverable.Throw(ctx, err) - } - return nil, rpc.ConvertStorageError(err) + if txInfo.Status, err = b.getTransactionStatus(ctx, txInfo, prevTxStatus); err != nil { + return nil, err } // If the old and new transaction statuses are still the same, the status change should not be reported, so @@ -145,6 +223,45 @@ func (b *backendSubscribeTransactions) getTransactionStatusResponse(txInfo *Tran } } +// handlePendingStatus handles the initial pending status for a transaction. +func (b *backendSubscribeTransactions) handlePendingStatus(txInfo *transactionSubscriptionMetadata) (interface{}, error) { + txInfo.shouldTriggerPending = false + return b.generateResultsWithMissingStatuses(txInfo, flow.TransactionStatusUnknown) +} + +// isTransactionFinalStatus checks if a transaction has reached a final state (Sealed or Expired). +func (b *backendSubscribeTransactions) isTransactionFinalStatus(txInfo *transactionSubscriptionMetadata) bool { + return txInfo.Status == flow.TransactionStatusSealed || txInfo.Status == flow.TransactionStatusExpired +} + +// getTransactionStatus determines the current status of a transaction based on its metadata +// and previous status. It derives the transaction status by analyzing the transaction's +// execution block, if available, or its reference block. +// +// No errors expected during normal operations. +func (b *backendSubscribeTransactions) getTransactionStatus(ctx context.Context, txInfo *transactionSubscriptionMetadata, prevTxStatus flow.TransactionStatus) (flow.TransactionStatus, error) { + txStatus := txInfo.Status + var err error + + if txInfo.blockWithTx == nil { + txStatus, err = b.txLocalDataProvider.DeriveUnknownTransactionStatus(txInfo.txReferenceBlockID) + } else if txStatus == prevTxStatus { + // When a block with the transaction is available, it is possible to receive a new transaction status while + // searching for the transaction result. Otherwise, it remains unchanged. So, if the old and new transaction + // statuses are the same, the current transaction status should be retrieved. + txStatus, err = b.txLocalDataProvider.DeriveTransactionStatus(txInfo.blockWithTx.Height, txInfo.txExecuted) + } + + if err != nil { + if !errors.Is(err, state.ErrUnknownSnapshotReference) { + irrecoverable.Throw(ctx, err) + } + return flow.TransactionStatusUnknown, rpc.ConvertStorageError(err) + } + + return txStatus, nil +} + // generateResultsWithMissingStatuses checks if the current result differs from the previous result by more than one step. // If yes, it generates results for the missing transaction statuses. This is done because the subscription should send // responses for each of the statuses in the transaction lifecycle, and the message should be sent in the order of transaction statuses. @@ -153,7 +270,7 @@ func (b *backendSubscribeTransactions) getTransactionStatusResponse(txInfo *Tran // 2. pending(1) -> expired(5) // No errors expected during normal operations. func (b *backendSubscribeTransactions) generateResultsWithMissingStatuses( - txInfo *TransactionSubscriptionMetadata, + txInfo *transactionSubscriptionMetadata, prevTxStatus flow.TransactionStatus, ) ([]*access.TransactionResult, error) { // If the previous status is pending and the new status is expired, which is the last status, return its result. @@ -228,7 +345,7 @@ func (b *backendSubscribeTransactions) checkBlockReady(height uint64) error { // - codes.Internal when other errors occur during block or collection lookup func (b *backendSubscribeTransactions) searchForTransactionBlockInfo( height uint64, - txInfo *TransactionSubscriptionMetadata, + txInfo *transactionSubscriptionMetadata, ) (*flow.Header, flow.Identifier, uint64, flow.Identifier, error) { block, err := b.txLocalDataProvider.blocks.ByHeight(height) if err != nil { @@ -252,7 +369,7 @@ func (b *backendSubscribeTransactions) searchForTransactionBlockInfo( // - codes.Internal if an internal error occurs while retrieving execution result. func (b *backendSubscribeTransactions) searchForTransactionResult( ctx context.Context, - txInfo *TransactionSubscriptionMetadata, + txInfo *transactionSubscriptionMetadata, ) (*access.TransactionResult, error) { _, err := b.executionResults.ByBlockID(txInfo.BlockID) if err != nil { diff --git a/engine/access/rpc/backend/backend_stream_transactions_test.go b/engine/access/rpc/backend/backend_stream_transactions_test.go index 24cdf601f17..52049a4b0ef 100644 --- a/engine/access/rpc/backend/backend_stream_transactions_test.go +++ b/engine/access/rpc/backend/backend_stream_transactions_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + accessproto "github.com/onflow/flow/protobuf/go/flow/access" + accessapi "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/access/index" @@ -121,6 +123,14 @@ func (s *TransactionStatusSuite) SetupTest() { s.blockTracker = subscriptionmock.NewBlockTracker(s.T()) s.resultsMap = map[flow.Identifier]*flow.ExecutionResult{} + s.colClient.On( + "SendTransaction", + mock.Anything, + mock.Anything, + ).Return(&accessproto.SendTransactionResponse{}, nil).Maybe() + + s.transactions.On("Store", mock.Anything).Return(nil).Maybe() + // generate blockCount consecutive blocks with associated seal, result and execution data s.rootBlock = unittest.BlockFixture() rootResult := unittest.ExecutionResultFixture(unittest.WithBlock(&s.rootBlock)) @@ -148,7 +158,7 @@ func (s *TransactionStatusSuite) SetupTest() { require.NoError(s.T(), err) s.blocks.On("ByHeight", mock.AnythingOfType("uint64")).Return(mocks.StorageMapGetter(s.blockMap)) - s.state.On("Final").Return(s.finalSnapshot, nil) + s.state.On("Final").Return(s.finalSnapshot, nil).Maybe() s.state.On("AtBlockID", mock.AnythingOfType("flow.Identifier")).Return(func(blockID flow.Identifier) protocolint.Snapshot { s.tempSnapshot.On("Head").Unset() s.tempSnapshot.On("Head").Return(func() *flow.Header { @@ -162,12 +172,12 @@ func (s *TransactionStatusSuite) SetupTest() { }, nil) return s.tempSnapshot - }, nil) + }, nil).Maybe() s.finalSnapshot.On("Head").Return(func() *flow.Header { finalizedHeader := s.finalizedBlock.Header return finalizedHeader - }, nil) + }, nil).Maybe() s.blockTracker.On("GetStartHeightFromBlockID", mock.Anything).Return(func(_ flow.Identifier) (uint64, error) { finalizedHeader := s.finalizedBlock.Header @@ -235,7 +245,7 @@ func (s *TransactionStatusSuite) addNewFinalizedBlock(parent *flow.Header, notif } } -// TestSubscribeTransactionStatusHappyCase tests the functionality of the SubscribeTransactionStatuses method in the Backend. +// TestSubscribeTransactionStatusHappyCase tests the functionality of the SubscribeTransactionStatusesFromStartBlockID method in the Backend. // It covers the emulation of transaction stages from pending to sealed, and receiving status updates. func (s *TransactionStatusSuite) TestSubscribeTransactionStatusHappyCase() { ctx, cancel := context.WithCancel(context.Background()) @@ -314,7 +324,7 @@ func (s *TransactionStatusSuite) TestSubscribeTransactionStatusHappyCase() { } // 1. Subscribe to transaction status and receive the first message with pending status - sub := s.backend.SubscribeTransactionStatuses(ctx, &transaction.TransactionBody, entities.EventEncodingVersion_CCF_V0) + sub := s.backend.SendAndSubscribeTransactionStatuses(ctx, &transaction.TransactionBody, entities.EventEncodingVersion_CCF_V0) checkNewSubscriptionMessage(sub, flow.TransactionStatusPending) // 2. Make transaction reference block sealed, and add a new finalized block that includes the transaction @@ -349,7 +359,7 @@ func (s *TransactionStatusSuite) TestSubscribeTransactionStatusHappyCase() { }, 100*time.Millisecond, "timed out waiting for subscription to shutdown") } -// TestSubscribeTransactionStatusExpired tests the functionality of the SubscribeTransactionStatuses method in the Backend +// TestSubscribeTransactionStatusExpired tests the functionality of the SubscribeTransactionStatusesFromStartBlockID method in the Backend // when transaction become expired func (s *TransactionStatusSuite) TestSubscribeTransactionStatusExpired() { ctx, cancel := context.WithCancel(context.Background()) @@ -380,7 +390,7 @@ func (s *TransactionStatusSuite) TestSubscribeTransactionStatusExpired() { } // Subscribe to transaction status and receive the first message with pending status - sub := s.backend.SubscribeTransactionStatuses(ctx, &transaction.TransactionBody, entities.EventEncodingVersion_CCF_V0) + sub := s.backend.SendAndSubscribeTransactionStatuses(ctx, &transaction.TransactionBody, entities.EventEncodingVersion_CCF_V0) checkNewSubscriptionMessage(sub, flow.TransactionStatusPending) // Generate 600 blocks without transaction included and check, that transaction still pending