diff --git a/config/consumer_examples/lava_consumer_static_peers.yml b/config/consumer_examples/lava_consumer_static_peers.yml index 5e3a6bfe5f..415f1f4c95 100644 --- a/config/consumer_examples/lava_consumer_static_peers.yml +++ b/config/consumer_examples/lava_consumer_static_peers.yml @@ -20,4 +20,16 @@ static-providers: - api-interface: rest chain-id: LAV1 node-urls: - - url: 127.0.0.1:2220 \ No newline at end of file + - url: 127.0.0.1:2220 + - api-interface: tendermintrpc + chain-id: LAV1 + node-urls: + - url: 127.0.0.1:2221 + - api-interface: grpc + chain-id: LAV1 + node-urls: + - url: 127.0.0.1:2221 + - api-interface: rest + chain-id: LAV1 + node-urls: + - url: 127.0.0.1:2221 \ No newline at end of file diff --git a/config/provider_examples/lava_example2.yml b/config/provider_examples/lava_example2.yml new file mode 100644 index 0000000000..37a1cfc68f --- /dev/null +++ b/config/provider_examples/lava_example2.yml @@ -0,0 +1,20 @@ +endpoints: + - api-interface: tendermintrpc + chain-id: LAV1 + network-address: + address: "127.0.0.1:2221" + node-urls: + - url: ws://127.0.0.1:26657/websocket + - url: http://127.0.0.1:26657 + - api-interface: grpc + chain-id: LAV1 + network-address: + address: "127.0.0.1:2221" + node-urls: + - url: 127.0.0.1:9090 + - api-interface: rest + chain-id: LAV1 + network-address: + address: "127.0.0.1:2221" + node-urls: + - url: http://127.0.0.1:1317 \ No newline at end of file diff --git a/protocol/chainlib/chain_router_test.go b/protocol/chainlib/chain_router_test.go index db1b9fa8d6..e2fe5e1c0a 100644 --- a/protocol/chainlib/chain_router_test.go +++ b/protocol/chainlib/chain_router_test.go @@ -1175,8 +1175,8 @@ func TestMain(m *testing.M) { listener := createRPCServer() for { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := rpcclient.DialContext(ctx, listenerAddressHttp) - _, err2 := rpcclient.DialContext(ctx, listenerAddressWs) + _, err := rpcclient.DialContext(ctx, listenerAddressHttp, nil) + _, err2 := rpcclient.DialContext(ctx, listenerAddressWs, nil) if err2 != nil { utils.LavaFormatDebug("waiting for grpc server to launch") continue diff --git a/protocol/chainlib/chainproxy/connector.go b/protocol/chainlib/chainproxy/connector.go index 307f86915a..f5c7d51ff7 100644 --- a/protocol/chainlib/chainproxy/connector.go +++ b/protocol/chainlib/chainproxy/connector.go @@ -107,6 +107,18 @@ func (connector *Connector) numberOfUsedClients() int { return int(atomic.LoadInt64(&connector.usedClients)) } +func (connector *Connector) getRpcClient(ctx context.Context, nodeUrl common.NodeUrl) (*rpcclient.Client, error) { + authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url) + // origin used for auth header in the websocket case + authHeaders := nodeUrl.GetAuthHeaders() + rpcClient, err := rpcclient.DialContext(ctx, authPathNodeUrl, authHeaders) + if err != nil { + return nil, err + } + nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader) + return rpcClient, nil +} + func (connector *Connector) createConnection(ctx context.Context, nodeUrl common.NodeUrl, currentNumberOfConnections int) (*rpcclient.Client, error) { var rpcClient *rpcclient.Client var err error @@ -124,21 +136,13 @@ func (connector *Connector) createConnection(ctx context.Context, nodeUrl common } timeout := common.AverageWorldLatency * (1 + time.Duration(numberOfConnectionAttempts)) nctx, cancel := nodeUrl.LowerContextTimeoutWithDuration(ctx, timeout) - // add auth path - authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url) - rpcClient, err = rpcclient.DialContext(nctx, authPathNodeUrl) + // get rpcClient + rpcClient, err = connector.getRpcClient(nctx, nodeUrl) if err != nil { - utils.LavaFormatWarning("Could not connect to the node, retrying", err, []utils.Attribute{ - {Key: "Current Number Of Connections", Value: currentNumberOfConnections}, - {Key: "Network Address", Value: authPathNodeUrl}, - {Key: "Number Of Attempts", Value: numberOfConnectionAttempts}, - {Key: "timeout", Value: timeout}, - }...) cancel() continue } cancel() - nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader) break } @@ -178,7 +182,8 @@ func (connector *Connector) increaseNumberOfClients(ctx context.Context, numberO var err error for connectionAttempt := 0; connectionAttempt < MaximumNumberOfParallelConnectionsAttempts; connectionAttempt++ { nctx, cancel := connector.nodeUrl.LowerContextTimeoutWithDuration(ctx, common.AverageWorldLatency*2) - rpcClient, err = rpcclient.DialContext(nctx, connector.nodeUrl.Url) + // get rpcClient + rpcClient, err = connector.getRpcClient(nctx, connector.nodeUrl) if err != nil { utils.LavaFormatDebug( "could no increase number of connections to the node jsonrpc connector, retrying", diff --git a/protocol/chainlib/chainproxy/connector_test.go b/protocol/chainlib/chainproxy/connector_test.go index e408ae2062..26a8100425 100644 --- a/protocol/chainlib/chainproxy/connector_test.go +++ b/protocol/chainlib/chainproxy/connector_test.go @@ -2,6 +2,7 @@ package chainproxy import ( "context" + "encoding/json" "fmt" "log" "net" @@ -16,6 +17,7 @@ import ( "github.com/lavanet/lava/v4/utils" pb_pkg "github.com/lavanet/lava/v4/x/spec/types" "github.com/stretchr/testify/require" + "golang.org/x/net/websocket" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -185,7 +187,7 @@ func TestMain(m *testing.M) { listener := createRPCServer() for { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := rpcclient.DialContext(ctx, listenerAddressTcp) + _, err := rpcclient.DialContext(ctx, listenerAddressTcp, nil) if err != nil { utils.LavaFormatDebug("waiting for grpc server to launch") continue @@ -199,3 +201,89 @@ func TestMain(m *testing.M) { listener.Close() os.Exit(code) } + +func TestConnectorWebsocket(t *testing.T) { + // Set up auth headers we expect + expectedAuthHeader := "Bearer test-token" + + // Create WebSocket server with auth check + srv := &http.Server{ + Addr: "localhost:0", // random available port + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check auth header + authHeader := r.Header.Get("Authorization") + if authHeader != expectedAuthHeader { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + fmt.Println("connection OK!") + // Upgrade to websocket + upgrader := websocket.Server{ + Handler: websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + // Simple echo server + for { + var msg string + err := websocket.Message.Receive(ws, &msg) + if err != nil { + break + } + websocket.Message.Send(ws, msg) + } + }), + } + upgrader.ServeHTTP(w, r) + }), + } + + // Start server + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer listener.Close() + + go srv.Serve(listener) + wsURL := "ws://" + listener.Addr().String() + + // Create connector with auth config + ctx := context.Background() + nodeUrl := common.NodeUrl{ + Url: wsURL, + AuthConfig: common.AuthConfig{ + AuthHeaders: map[string]string{ + "Authorization": expectedAuthHeader, + }, + }, + } + + // Create connector + conn, err := NewConnector(ctx, numberOfClients, nodeUrl) + require.NoError(t, err) + defer conn.Close() + + // Wait for connections to be established + for { + if len(conn.freeClients) == numberOfClients { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Get a client and test the connection + client, err := conn.GetRpc(ctx, true) + require.NoError(t, err) + + // Test sending a message using CallContext + params := map[string]interface{}{ + "test": "value", + } + id := json.RawMessage(`1`) + _, err = client.CallContext(ctx, id, "test_method", params, true, true) + require.NoError(t, err) + + // Return the client + conn.ReturnRpc(client) + + // Verify connection pool state + require.Equal(t, int64(0), conn.usedClients) + require.Equal(t, numberOfClients, len(conn.freeClients)) +} diff --git a/protocol/chainlib/chainproxy/rpcclient/client.go b/protocol/chainlib/chainproxy/rpcclient/client.go index c745f82ee2..15fcdb9b73 100644 --- a/protocol/chainlib/chainproxy/rpcclient/client.go +++ b/protocol/chainlib/chainproxy/rpcclient/client.go @@ -177,14 +177,14 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*JsonrpcMessage, erro // // The client reconnects automatically if the connection is lost. func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) + return DialContext(context.Background(), rawurl, nil) } // DialContext creates a new RPC client, just like Dial. // // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. -func DialContext(ctx context.Context, rawurl string) (*Client, error) { +func DialContext(ctx context.Context, rawurl string, wsHeaders map[string]string) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err @@ -193,7 +193,7 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) { case "http", "https": return DialHTTP(rawurl) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + return DialWebsocket(ctx, rawurl, wsHeaders) case "stdio": return DialStdIO(ctx) case "": diff --git a/protocol/chainlib/chainproxy/rpcclient/websocket.go b/protocol/chainlib/chainproxy/rpcclient/websocket.go index 81566ffe0f..680d37167a 100755 --- a/protocol/chainlib/chainproxy/rpcclient/websocket.go +++ b/protocol/chainlib/chainproxy/rpcclient/websocket.go @@ -187,8 +187,8 @@ func parseOriginURL(origin string) (string, string, string, error) { // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. -func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { - endpoint, header, err := wsClientHeaders(endpoint, origin) +func DialWebsocketWithDialer(ctx context.Context, endpoint string, dialer websocket.Dialer, headers map[string]string) (*Client, error) { + endpoint, header, err := wsClientHeaders(endpoint, headers) if err != nil { return nil, err } @@ -210,23 +210,23 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +func DialWebsocket(ctx context.Context, endpoint string, headers map[string]string) (*Client, error) { dialer := websocket.Dialer{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, WriteBufferPool: wsBufferPool, } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + return DialWebsocketWithDialer(ctx, endpoint, dialer, headers) } -func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { +func wsClientHeaders(endpoint string, headers map[string]string) (string, http.Header, error) { endpointURL, err := url.Parse(endpoint) if err != nil { return endpoint, nil, err } header := make(http.Header) - if origin != "" { - header.Add("origin", origin) + for headerKey, headerValue := range headers { + header.Add(headerKey, headerValue) } if endpointURL.User != nil { b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) diff --git a/protocol/common/endpoints.go b/protocol/common/endpoints.go index 2379512708..098b301b57 100644 --- a/protocol/common/endpoints.go +++ b/protocol/common/endpoints.go @@ -91,6 +91,10 @@ func (nurl *NodeUrl) UrlStr() string { return parsedURL.String() } +func (url *NodeUrl) GetAuthHeaders() map[string]string { + return url.AuthConfig.AuthHeaders +} + func (url *NodeUrl) SetAuthHeaders(ctx context.Context, headerSetter func(string, string)) { for header, headerValue := range url.AuthConfig.AuthHeaders { headerSetter(header, headerValue) diff --git a/protocol/integration/mocks.go b/protocol/integration/mocks.go index 55ec653fe2..248db05032 100644 --- a/protocol/integration/mocks.go +++ b/protocol/integration/mocks.go @@ -39,7 +39,7 @@ func (m *mockConsumerStateTracker) RegisterForSpecUpdates(ctx context.Context, s return nil } -func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) { +func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus, bool) { } func (m *mockConsumerStateTracker) RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error { diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go index efa4d43352..c1e179edde 100644 --- a/protocol/lavasession/consumer_session_manager.go +++ b/protocol/lavasession/consumer_session_manager.go @@ -75,6 +75,7 @@ func (csm *ConsumerSessionManager) RPCEndpoint() RPCEndpoint { } func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList map[uint64]*ConsumerSessionsWithProvider) error { + utils.LavaFormatDebug("UpdateAllProviders", utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "pairingListLen", Value: len(pairingList)}) pairingListLength := len(pairingList) // TODO: we can block updating until some of the probing is done, this can prevent failed attempts on epoch change when we have no information on the providers, // and all of them are new (less effective on big pairing lists or a process that runs for a few epochs) diff --git a/protocol/rpcconsumer/consumer_state_tracker_mock.go b/protocol/rpcconsumer/consumer_state_tracker_mock.go index 11663b29e7..5e25fa4d88 100644 --- a/protocol/rpcconsumer/consumer_state_tracker_mock.go +++ b/protocol/rpcconsumer/consumer_state_tracker_mock.go @@ -27,6 +27,7 @@ import ( type MockConsumerStateTrackerInf struct { ctrl *gomock.Controller recorder *MockConsumerStateTrackerInfMockRecorder + isgomock struct{} } // MockConsumerStateTrackerInfMockRecorder is the mock recorder for MockConsumerStateTrackerInf. @@ -91,27 +92,27 @@ func (mr *MockConsumerStateTrackerInfMockRecorder) GetProtocolVersion(ctx any) * } // RegisterConsumerSessionManagerForPairingUpdates mocks base method. -func (m *MockConsumerStateTrackerInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProviders []*lavasession.RPCProviderEndpoint) { +func (m *MockConsumerStateTrackerInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint) { m.ctrl.T.Helper() - m.ctrl.Call(m, "RegisterConsumerSessionManagerForPairingUpdates", ctx, consumerSessionManager) + m.ctrl.Call(m, "RegisterConsumerSessionManagerForPairingUpdates", ctx, consumerSessionManager, staticProvidersList) } // RegisterConsumerSessionManagerForPairingUpdates indicates an expected call of RegisterConsumerSessionManagerForPairingUpdates. -func (mr *MockConsumerStateTrackerInfMockRecorder) RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager any) *gomock.Call { +func (mr *MockConsumerStateTrackerInfMockRecorder) RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager, staticProvidersList any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterConsumerSessionManagerForPairingUpdates", reflect.TypeOf((*MockConsumerStateTrackerInf)(nil).RegisterConsumerSessionManagerForPairingUpdates), ctx, consumerSessionManager) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterConsumerSessionManagerForPairingUpdates", reflect.TypeOf((*MockConsumerStateTrackerInf)(nil).RegisterConsumerSessionManagerForPairingUpdates), ctx, consumerSessionManager, staticProvidersList) } // RegisterFinalizationConsensusForUpdates mocks base method. -func (m *MockConsumerStateTrackerInf) RegisterFinalizationConsensusForUpdates(arg0 context.Context, arg1 *finalizationconsensus.FinalizationConsensus) { +func (m *MockConsumerStateTrackerInf) RegisterFinalizationConsensusForUpdates(arg0 context.Context, arg1 *finalizationconsensus.FinalizationConsensus, arg2 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "RegisterFinalizationConsensusForUpdates", arg0, arg1) + m.ctrl.Call(m, "RegisterFinalizationConsensusForUpdates", arg0, arg1, arg2) } // RegisterFinalizationConsensusForUpdates indicates an expected call of RegisterFinalizationConsensusForUpdates. -func (mr *MockConsumerStateTrackerInfMockRecorder) RegisterFinalizationConsensusForUpdates(arg0, arg1 any) *gomock.Call { +func (mr *MockConsumerStateTrackerInfMockRecorder) RegisterFinalizationConsensusForUpdates(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterFinalizationConsensusForUpdates", reflect.TypeOf((*MockConsumerStateTrackerInf)(nil).RegisterFinalizationConsensusForUpdates), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterFinalizationConsensusForUpdates", reflect.TypeOf((*MockConsumerStateTrackerInf)(nil).RegisterFinalizationConsensusForUpdates), arg0, arg1, arg2) } // RegisterForDowntimeParamsUpdates mocks base method. diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 4687d07ec3..c141fb97cb 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -96,7 +96,7 @@ type ConsumerStateTrackerInf interface { RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint) RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error - RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) + RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus, bool) RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, conflictHandler common.ConflictHandlerInterface) error GetConsumerPolicy(ctx context.Context, consumerAddress, chainID string) (*plantypes.Policy, error) @@ -348,6 +348,15 @@ func (rpcc *RPCConsumer) CreateConsumerEndpoint( return nil, err } + // Filter the relevant static providers + relevantStaticProviderList := []*lavasession.RPCProviderEndpoint{} + for _, staticProvider := range options.staticProvidersList { + if staticProvider.ChainID == rpcEndpoint.ChainID { + relevantStaticProviderList = append(relevantStaticProviderList, staticProvider) + } + } + staticProvidersActive := len(relevantStaticProviderList) > 0 + _, averageBlockTime, _, _ := chainParser.ChainBlockStats() var optimizer *provideroptimizer.ProviderOptimizer var consumerConsistency *ConsumerConsistency @@ -387,7 +396,7 @@ func (rpcc *RPCConsumer) CreateConsumerEndpoint( return utils.LavaFormatError("failed loading finalization consensus", err, utils.LogAttr("endpoint", rpcEndpoint.Key())) } if !loaded { // when creating new finalization consensus instance we need to register it to updates - consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus) + consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus, staticProvidersActive) } return nil } diff --git a/protocol/statetracker/consumer_state_tracker.go b/protocol/statetracker/consumer_state_tracker.go index db91d8312e..da6a1c8b7c 100644 --- a/protocol/statetracker/consumer_state_tracker.go +++ b/protocol/statetracker/consumer_state_tracker.go @@ -94,8 +94,8 @@ func (cst *ConsumerStateTracker) RegisterForPairingUpdates(ctx context.Context, } } -func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *finalizationconsensus.FinalizationConsensus) { - finalizationConsensusUpdater := updaters.NewFinalizationConsensusUpdater(cst.StateQuery, finalizationConsensus.SpecId) +func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *finalizationconsensus.FinalizationConsensus, ignoreQueryErrors bool) { + finalizationConsensusUpdater := updaters.NewFinalizationConsensusUpdater(cst.StateQuery, finalizationConsensus.SpecId, ignoreQueryErrors) finalizationConsensusUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, finalizationConsensusUpdater) finalizationConsensusUpdater, ok := finalizationConsensusUpdaterRaw.(*updaters.FinalizationConsensusUpdater) if !ok { diff --git a/protocol/statetracker/updaters/finalization_consensus_updater.go b/protocol/statetracker/updaters/finalization_consensus_updater.go index d6aa4bb1d6..b5cd25e87a 100644 --- a/protocol/statetracker/updaters/finalization_consensus_updater.go +++ b/protocol/statetracker/updaters/finalization_consensus_updater.go @@ -19,10 +19,11 @@ type FinalizationConsensusUpdater struct { nextBlockForUpdate uint64 stateQuery *ConsumerStateQuery specId string + ignoreQueryErrors bool // used when static providers are configured so we don't spam errors on failed get pairing. } -func NewFinalizationConsensusUpdater(stateQuery *ConsumerStateQuery, specId string) *FinalizationConsensusUpdater { - return &FinalizationConsensusUpdater{registeredFinalizationConsensuses: []*finalizationconsensus.FinalizationConsensus{}, stateQuery: stateQuery, specId: specId} +func NewFinalizationConsensusUpdater(stateQuery *ConsumerStateQuery, specId string, ignoreQueryErrors bool) *FinalizationConsensusUpdater { + return &FinalizationConsensusUpdater{registeredFinalizationConsensuses: []*finalizationconsensus.FinalizationConsensus{}, stateQuery: stateQuery, specId: specId, ignoreQueryErrors: ignoreQueryErrors} } func (fcu *FinalizationConsensusUpdater) RegisterFinalizationConsensus(finalizationConsensus *finalizationconsensus.FinalizationConsensus) { @@ -45,8 +46,10 @@ func (fcu *FinalizationConsensusUpdater) updateInner(latestBlock int64) { timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second*3) defer cancel() _, epoch, nextBlockForUpdate, err := fcu.stateQuery.GetPairing(timeoutCtx, fcu.specId, latestBlock) - if err != nil { - utils.LavaFormatError("could not get block stats for finalization consensus updater, trying again next block", err, utils.Attribute{Key: "latestBlock", Value: latestBlock}) + if err != nil && epoch == 0 { + if !fcu.ignoreQueryErrors { + utils.LavaFormatError("could not get block stats for finalization consensus updater, trying again next block", err, utils.Attribute{Key: "latestBlock", Value: latestBlock}) + } fcu.nextBlockForUpdate += 1 return } diff --git a/protocol/statetracker/updaters/pairing_updater.go b/protocol/statetracker/updaters/pairing_updater.go index 661c2dc9d9..64bed77df0 100644 --- a/protocol/statetracker/updaters/pairing_updater.go +++ b/protocol/statetracker/updaters/pairing_updater.go @@ -46,7 +46,7 @@ func NewPairingUpdater(stateQuery ConsumerStateQueryInf, specId string) *Pairing return &PairingUpdater{consumerSessionManagersMap: map[string][]ConsumerSessionManagerInf{}, stateQuery: stateQuery, specId: specId, staticProviders: []*lavasession.RPCProviderEndpoint{}} } -func (pu *PairingUpdater) updateStaticProviders(staticProviders []*lavasession.RPCProviderEndpoint) { +func (pu *PairingUpdater) updateStaticProviders(staticProviders []*lavasession.RPCProviderEndpoint) int { pu.lock.Lock() defer pu.lock.Unlock() if len(staticProviders) > 0 && len(pu.staticProviders) == 0 { @@ -56,6 +56,14 @@ func (pu *PairingUpdater) updateStaticProviders(staticProviders []*lavasession.R } } } + // return length of relevant static providers + return len(pu.staticProviders) +} + +func (pu *PairingUpdater) getNumberOfStaticProviders() int { + pu.lock.RLock() + defer pu.lock.RUnlock() + return len(pu.staticProviders) } func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionManager ConsumerSessionManagerInf, staticProviders []*lavasession.RPCProviderEndpoint) error { @@ -63,10 +71,10 @@ func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionMa timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() pairingList, epoch, nextBlockForUpdate, err := pu.stateQuery.GetPairing(timeoutCtx, chainID, -1) - if err != nil { + numberOfRelevantProviders := pu.updateStaticProviders(staticProviders) + if err != nil && (epoch == 0 || numberOfRelevantProviders == 0) { return err } - pu.updateStaticProviders(staticProviders) pu.updateConsumerSessionManager(ctx, pairingList, consumerSessionManager, epoch) if nextBlockForUpdate > pu.nextBlockForUpdate { // make sure we don't update twice, this updates pu.nextBlockForUpdate @@ -87,7 +95,7 @@ func (pu *PairingUpdater) RegisterPairingUpdatable(ctx context.Context, pairingU pu.lock.Lock() defer pu.lock.Unlock() _, epoch, _, err := pu.stateQuery.GetPairing(ctx, pu.specId, -1) - if err != nil { + if err != nil && (epoch == 0 || len(pu.staticProviders) == 0) { return err } @@ -104,23 +112,20 @@ func (pu *PairingUpdater) updateInner(latestBlock int64) { pu.lock.RLock() defer pu.lock.RUnlock() ctx := context.Background() - if int64(pu.nextBlockForUpdate) > latestBlock { return } nextBlockForUpdateList := []uint64{} for chainID, consumerSessionManagerList := range pu.consumerSessionManagersMap { timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() pairingList, epoch, nextBlockForUpdate, err := pu.stateQuery.GetPairing(timeoutCtx, chainID, latestBlock) - cancel() - if err != nil { + nextBlockForUpdateList = append(nextBlockForUpdateList, nextBlockForUpdate) + if err != nil && (epoch == 0 || len(pu.staticProviders) == 0) { + // it's ok that we don't have pairing, only if there are static providers and epoch is not 0 utils.LavaFormatError("could not update pairing for chain, trying again next block", err, utils.Attribute{Key: "chain", Value: chainID}) - nextBlockForUpdateList = append(nextBlockForUpdateList, pu.nextBlockForUpdate+1) continue - } else { - nextBlockForUpdateList = append(nextBlockForUpdateList, nextBlockForUpdate) } - for _, consumerSessionManager := range consumerSessionManagerList { // same pairing for all apiInterfaces, they pick the right endpoints from inside using our filter function err = pu.updateConsumerSessionManager(ctx, pairingList, consumerSessionManager, epoch) @@ -134,11 +139,15 @@ func (pu *PairingUpdater) updateInner(latestBlock int64) { // get latest epoch from cache timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, epoch, _, err := pu.stateQuery.GetPairing(timeoutCtx, pu.specId, latestBlock) - if err != nil { + _, epoch, nextPairingUpdateBlock, err := pu.stateQuery.GetPairing(timeoutCtx, pu.specId, latestBlock) + if err != nil && (epoch == 0 || len(pu.staticProviders) == 0) { utils.LavaFormatError("could not update pairing for updatables, trying again next block", err) + } + + if epoch == 0 { nextBlockForUpdateList = append(nextBlockForUpdateList, pu.nextBlockForUpdate+1) } else { + nextBlockForUpdateList = append(nextBlockForUpdateList, nextPairingUpdateBlock) for _, updatable := range pu.pairingUpdatables { (*updatable).UpdateEpoch(epoch) } @@ -192,10 +201,12 @@ func (pu *PairingUpdater) addStaticProvidersToPairingList(pairingList map[uint64 for _, extension := range url.Addons { extensions[extension] = struct{}{} } + + // TODO might be problematic adding both addons and extensions with same map. endpoint := &lavasession.Endpoint{ NetworkAddress: url.Url, Enabled: true, - Addons: map[string]struct{}{}, // TODO: does not support addons, if required need to add the functionality to differentiate the two + Addons: extensions, Extensions: extensions, Connections: []*lavasession.EndpointConnection{}, } @@ -261,7 +272,7 @@ func (pu *PairingUpdater) filterPairingListByEndpoint(ctx context.Context, curre totalStakeIncludingDelegation, ) } - if len(pairing) == 0 { + if len(pairing)+pu.getNumberOfStaticProviders() == 0 { return nil, utils.LavaFormatError("Failed getting pairing for consumer, pairing is empty", err, utils.Attribute{Key: "apiInterface", Value: rpcEndpoint.ApiInterface}, utils.Attribute{Key: "ChainID", Value: rpcEndpoint.ChainID}, utils.Attribute{Key: "geolocation", Value: rpcEndpoint.Geolocation}) } // replace previous pairing with new providers diff --git a/protocol/statetracker/updaters/state_query.go b/protocol/statetracker/updaters/state_query.go index a549108659..1109264b97 100644 --- a/protocol/statetracker/updaters/state_query.go +++ b/protocol/statetracker/updaters/state_query.go @@ -14,7 +14,9 @@ import ( reliabilitymanager "github.com/lavanet/lava/v4/protocol/rpcprovider/reliabilitymanager" "github.com/lavanet/lava/v4/utils" conflicttypes "github.com/lavanet/lava/v4/x/conflict/types" + epochkeeper "github.com/lavanet/lava/v4/x/epochstorage/keeper" epochstoragetypes "github.com/lavanet/lava/v4/x/epochstorage/types" + pairingkeeper "github.com/lavanet/lava/v4/x/pairing/keeper" pairingtypes "github.com/lavanet/lava/v4/x/pairing/types" plantypes "github.com/lavanet/lava/v4/x/plans/types" protocoltypes "github.com/lavanet/lava/v4/x/protocol/types" @@ -183,7 +185,17 @@ func (csq *ConsumerStateQuery) GetPairing(ctx context.Context, chainID string, l Client: csq.fromAddress, }) if err != nil { - return nil, 0, 0, utils.LavaFormatError("Failed in get pairing query", err, utils.Attribute{}) + // if we can't get pairing, try to get epoch details and params + epochParamsResp, epochParamsErr := csq.epochStorageQueryClient.Params(ctx, &epochstoragetypes.QueryParamsRequest{}) + epochDetailsResp, epochDetailsErr := csq.epochStorageQueryClient.EpochDetails(ctx, &epochstoragetypes.QueryGetEpochDetailsRequest{}) + pairingParamsResp, pairingParamsErr := csq.pairingQueryClient.Params(ctx, &pairingtypes.QueryParamsRequest{}) + if epochDetailsErr != nil || epochParamsErr != nil || pairingParamsErr != nil { + return nil, 0, 0, err // if we can't get epoch details or params, return the original error + } + + nextEpochBlock := epochkeeper.CalculateNextEpochBlock(epochDetailsResp.EpochDetails.StartBlock, epochParamsResp.Params.EpochBlocks) + nextPairingBlock := pairingkeeper.CalculateNextPairingUpdateBlock(nextEpochBlock, pairingParamsResp.Params.EpochBlocksOverlap) + return nil, epochDetailsResp.EpochDetails.StartBlock, nextPairingBlock, err } csq.lastChainID = chainID csq.ResponsesCache.SetWithTTL(PairingRespKey+chainID, pairingResp, 1, DefaultTimeToLiveExpiration) diff --git a/scripts/pre_setups/init_lava_static_provider.sh b/scripts/pre_setups/init_lava_static_provider.sh index 0ad135c638..051c158231 100755 --- a/scripts/pre_setups/init_lava_static_provider.sh +++ b/scripts/pre_setups/init_lava_static_provider.sh @@ -44,15 +44,18 @@ PROVIDER4_LISTENER="127.0.0.1:2220" lavad tx subscription buy DefaultPlan $(lavad keys show user1 -a) -y --from user1 --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE wait_next_block -lavad tx pairing stake-provider "LAV1" $PROVIDERSTAKE "$PROVIDER1_LISTENER,1" 1 $(operator_address) -y --from servicer1 --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +# lavad tx pairing stake-provider "LAV1" $PROVIDERSTAKE "$PROVIDER1_LISTENER,1" 1 $(operator_address) -y --from servicer1 --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE sleep_until_next_epoch screen -d -m -S provider4 bash -c "source ~/.bashrc; lavap rpcprovider provider_examples/lava_example.yml\ $EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level debug --from servicer4 --static-providers --chain-id lava 2>&1 | tee $LOGS_DIR/PROVIDER4.log" && sleep 0.25 +screen -d -m -S provider3 bash -c "source ~/.bashrc; lavap rpcprovider provider_examples/lava_example2.yml\ +$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level debug --from servicer3 --static-providers --chain-id lava 2>&1 | tee $LOGS_DIR/PROVIDER3.log" && sleep 0.25 + screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer consumer_examples/lava_consumer_static_peers.yml \ -$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level debug --from user1 --chain-id lava --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level debug --from user1 --chain-id lava --allow-insecure-provider-dialing --metrics-listen-address ":7779" --enable-provider-optimizer-auto-adjustment-of-tiers --use-lava-over-lava-backup=false 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 echo "--- setting up screens done ---" screen -ls \ No newline at end of file diff --git a/scripts/test/init_payment_e2e.sh b/scripts/test/init_payment_e2e.sh index 7e662fd52a..98870657f3 100755 --- a/scripts/test/init_payment_e2e.sh +++ b/scripts/test/init_payment_e2e.sh @@ -27,6 +27,7 @@ sleep 6 STAKE="500000000000ulava" # Lava tendermint/rest providers +wait_next_block lavad tx pairing stake-provider "LAV1" $STAKE "127.0.0.1:2261,1" 1 $(operator_address) -y --from servicer1 --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE lavad tx pairing stake-provider "LAV1" $STAKE "127.0.0.1:2262,1" 1 $(operator_address) -y --from servicer2 --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE diff --git a/x/epochstorage/keeper/params.go b/x/epochstorage/keeper/params.go index 13dc516a32..da6cfcd38d 100644 --- a/x/epochstorage/keeper/params.go +++ b/x/epochstorage/keeper/params.go @@ -105,10 +105,14 @@ func (k Keeper) GetEpochStartForBlock(ctx sdk.Context, block uint64) (epochStart return targetEpochStart, blockInTargetEpoch, err } +func CalculateNextEpochBlock(epochStart uint64, epochBlocks uint64) uint64 { + return epochStart + epochBlocks +} + func (k Keeper) GetNextEpoch(ctx sdk.Context, block uint64) (nextEpoch uint64, erro error) { epochBlocks, err := k.EpochBlocks(ctx, block) epochStart, _, err2 := k.GetEpochStartForBlock(ctx, block) - nextEpoch = epochStart + epochBlocks + nextEpoch = CalculateNextEpochBlock(epochStart, epochBlocks) if err != nil { erro = err } else if err2 != nil { @@ -121,7 +125,7 @@ func (k Keeper) GetCurrentNextEpoch(ctx sdk.Context) (nextEpoch uint64) { epochBlocks := k.EpochBlocksRaw(ctx) details, found := k.GetEpochDetails(ctx) if details.EarliestStart == details.StartBlock { - nextEpoch = details.StartBlock + epochBlocks + nextEpoch = CalculateNextEpochBlock(details.StartBlock, epochBlocks) if !found { utils.LavaFormatPanic("blabla", nil) } diff --git a/x/pairing/keeper/pairing_next_epoch_time_block.go b/x/pairing/keeper/pairing_next_epoch_time_block.go index 881876704e..94253c2a5c 100644 --- a/x/pairing/keeper/pairing_next_epoch_time_block.go +++ b/x/pairing/keeper/pairing_next_epoch_time_block.go @@ -16,6 +16,10 @@ const ( MIN_SAMPLE_STEP uint64 = 1 // the minimal sample step when calculating the average block time ) +func CalculateNextPairingUpdateBlock(nextEpochBlock uint64, epochBlocksOverlap uint64) uint64 { + return nextEpochBlock + epochBlocksOverlap +} + // Function to calculate how much time (in seconds) is left until the next epoch func (k Keeper) calculateNextEpochTimeAndBlock(ctx sdk.Context) (uint64, uint64, error) { // Get current epoch @@ -37,7 +41,7 @@ func (k Keeper) calculateNextEpochTimeAndBlock(ctx sdk.Context) (uint64, uint64, overlapBlocks := k.EpochBlocksOverlap(ctx) // calculate the block in which the next pairing will happen (+overlap) - nextPairingBlock := nextEpochStart + overlapBlocks + nextPairingBlock := CalculateNextPairingUpdateBlock(nextEpochStart, overlapBlocks) // Get number of blocks from the current block to the next epoch blocksUntilNewEpoch := nextPairingBlock - uint64(ctx.BlockHeight())