Skip to content

Commit

Permalink
Merge branch 'main' into CNS-1003-reputation-proto-definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yaroms authored Jan 19, 2025
2 parents 543a6f9 + ec071f0 commit 7e6d535
Show file tree
Hide file tree
Showing 20 changed files with 241 additions and 63 deletions.
14 changes: 13 additions & 1 deletion config/consumer_examples/lava_consumer_static_peers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,16 @@ static-providers:
- api-interface: rest
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2220
- url: 127.0.0.1:2220
- api-interface: tendermintrpc
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2221
- api-interface: grpc
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2221
- api-interface: rest
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2221
20 changes: 20 additions & 0 deletions config/provider_examples/lava_example2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
endpoints:
- api-interface: tendermintrpc
chain-id: LAV1
network-address:
address: "127.0.0.1:2221"
node-urls:
- url: ws://127.0.0.1:26657/websocket
- url: http://127.0.0.1:26657
- api-interface: grpc
chain-id: LAV1
network-address:
address: "127.0.0.1:2221"
node-urls:
- url: 127.0.0.1:9090
- api-interface: rest
chain-id: LAV1
network-address:
address: "127.0.0.1:2221"
node-urls:
- url: http://127.0.0.1:1317
4 changes: 2 additions & 2 deletions protocol/chainlib/chain_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,8 @@ func TestMain(m *testing.M) {
listener := createRPCServer()
for {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := rpcclient.DialContext(ctx, listenerAddressHttp)
_, err2 := rpcclient.DialContext(ctx, listenerAddressWs)
_, err := rpcclient.DialContext(ctx, listenerAddressHttp, nil)
_, err2 := rpcclient.DialContext(ctx, listenerAddressWs, nil)
if err2 != nil {
utils.LavaFormatDebug("waiting for grpc server to launch")
continue
Expand Down
27 changes: 16 additions & 11 deletions protocol/chainlib/chainproxy/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ func (connector *Connector) numberOfUsedClients() int {
return int(atomic.LoadInt64(&connector.usedClients))
}

func (connector *Connector) getRpcClient(ctx context.Context, nodeUrl common.NodeUrl) (*rpcclient.Client, error) {
authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url)
// origin used for auth header in the websocket case
authHeaders := nodeUrl.GetAuthHeaders()
rpcClient, err := rpcclient.DialContext(ctx, authPathNodeUrl, authHeaders)
if err != nil {
return nil, err
}
nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader)
return rpcClient, nil
}

func (connector *Connector) createConnection(ctx context.Context, nodeUrl common.NodeUrl, currentNumberOfConnections int) (*rpcclient.Client, error) {
var rpcClient *rpcclient.Client
var err error
Expand All @@ -124,21 +136,13 @@ func (connector *Connector) createConnection(ctx context.Context, nodeUrl common
}
timeout := common.AverageWorldLatency * (1 + time.Duration(numberOfConnectionAttempts))
nctx, cancel := nodeUrl.LowerContextTimeoutWithDuration(ctx, timeout)
// add auth path
authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url)
rpcClient, err = rpcclient.DialContext(nctx, authPathNodeUrl)
// get rpcClient
rpcClient, err = connector.getRpcClient(nctx, nodeUrl)
if err != nil {
utils.LavaFormatWarning("Could not connect to the node, retrying", err, []utils.Attribute{
{Key: "Current Number Of Connections", Value: currentNumberOfConnections},
{Key: "Network Address", Value: authPathNodeUrl},
{Key: "Number Of Attempts", Value: numberOfConnectionAttempts},
{Key: "timeout", Value: timeout},
}...)
cancel()
continue
}
cancel()
nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader)
break
}

Expand Down Expand Up @@ -178,7 +182,8 @@ func (connector *Connector) increaseNumberOfClients(ctx context.Context, numberO
var err error
for connectionAttempt := 0; connectionAttempt < MaximumNumberOfParallelConnectionsAttempts; connectionAttempt++ {
nctx, cancel := connector.nodeUrl.LowerContextTimeoutWithDuration(ctx, common.AverageWorldLatency*2)
rpcClient, err = rpcclient.DialContext(nctx, connector.nodeUrl.Url)
// get rpcClient
rpcClient, err = connector.getRpcClient(nctx, connector.nodeUrl)
if err != nil {
utils.LavaFormatDebug(
"could no increase number of connections to the node jsonrpc connector, retrying",
Expand Down
90 changes: 89 additions & 1 deletion protocol/chainlib/chainproxy/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chainproxy

import (
"context"
"encoding/json"
"fmt"
"log"
"net"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/lavanet/lava/v4/utils"
pb_pkg "github.com/lavanet/lava/v4/x/spec/types"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
Expand Down Expand Up @@ -185,7 +187,7 @@ func TestMain(m *testing.M) {
listener := createRPCServer()
for {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := rpcclient.DialContext(ctx, listenerAddressTcp)
_, err := rpcclient.DialContext(ctx, listenerAddressTcp, nil)
if err != nil {
utils.LavaFormatDebug("waiting for grpc server to launch")
continue
Expand All @@ -199,3 +201,89 @@ func TestMain(m *testing.M) {
listener.Close()
os.Exit(code)
}

func TestConnectorWebsocket(t *testing.T) {
// Set up auth headers we expect
expectedAuthHeader := "Bearer test-token"

// Create WebSocket server with auth check
srv := &http.Server{
Addr: "localhost:0", // random available port
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check auth header
authHeader := r.Header.Get("Authorization")
if authHeader != expectedAuthHeader {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
fmt.Println("connection OK!")
// Upgrade to websocket
upgrader := websocket.Server{
Handler: websocket.Handler(func(ws *websocket.Conn) {
defer ws.Close()
// Simple echo server
for {
var msg string
err := websocket.Message.Receive(ws, &msg)
if err != nil {
break
}
websocket.Message.Send(ws, msg)
}
}),
}
upgrader.ServeHTTP(w, r)
}),
}

// Start server
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer listener.Close()

go srv.Serve(listener)
wsURL := "ws://" + listener.Addr().String()

// Create connector with auth config
ctx := context.Background()
nodeUrl := common.NodeUrl{
Url: wsURL,
AuthConfig: common.AuthConfig{
AuthHeaders: map[string]string{
"Authorization": expectedAuthHeader,
},
},
}

// Create connector
conn, err := NewConnector(ctx, numberOfClients, nodeUrl)
require.NoError(t, err)
defer conn.Close()

// Wait for connections to be established
for {
if len(conn.freeClients) == numberOfClients {
break
}
time.Sleep(10 * time.Millisecond)
}

// Get a client and test the connection
client, err := conn.GetRpc(ctx, true)
require.NoError(t, err)

// Test sending a message using CallContext
params := map[string]interface{}{
"test": "value",
}
id := json.RawMessage(`1`)
_, err = client.CallContext(ctx, id, "test_method", params, true, true)
require.NoError(t, err)

// Return the client
conn.ReturnRpc(client)

// Verify connection pool state
require.Equal(t, int64(0), conn.usedClients)
require.Equal(t, numberOfClients, len(conn.freeClients))
}
6 changes: 3 additions & 3 deletions protocol/chainlib/chainproxy/rpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,14 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*JsonrpcMessage, erro
//
// The client reconnects automatically if the connection is lost.
func Dial(rawurl string) (*Client, error) {
return DialContext(context.Background(), rawurl)
return DialContext(context.Background(), rawurl, nil)
}

// DialContext creates a new RPC client, just like Dial.
//
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
func DialContext(ctx context.Context, rawurl string, wsHeaders map[string]string) (*Client, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
Expand All @@ -193,7 +193,7 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) {
case "http", "https":
return DialHTTP(rawurl)
case "ws", "wss":
return DialWebsocket(ctx, rawurl, "")
return DialWebsocket(ctx, rawurl, wsHeaders)
case "stdio":
return DialStdIO(ctx)
case "":
Expand Down
14 changes: 7 additions & 7 deletions protocol/chainlib/chainproxy/rpcclient/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ func parseOriginURL(origin string) (string, string, string, error) {

// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint using the provided dialer.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, origin)
func DialWebsocketWithDialer(ctx context.Context, endpoint string, dialer websocket.Dialer, headers map[string]string) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, headers)
if err != nil {
return nil, err
}
Expand All @@ -210,23 +210,23 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
//
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
func DialWebsocket(ctx context.Context, endpoint string, headers map[string]string) (*Client, error) {
dialer := websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
return DialWebsocketWithDialer(ctx, endpoint, dialer, headers)
}

func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
func wsClientHeaders(endpoint string, headers map[string]string) (string, http.Header, error) {
endpointURL, err := url.Parse(endpoint)
if err != nil {
return endpoint, nil, err
}
header := make(http.Header)
if origin != "" {
header.Add("origin", origin)
for headerKey, headerValue := range headers {
header.Add(headerKey, headerValue)
}
if endpointURL.User != nil {
b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
Expand Down
4 changes: 4 additions & 0 deletions protocol/common/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func (nurl *NodeUrl) UrlStr() string {
return parsedURL.String()
}

func (url *NodeUrl) GetAuthHeaders() map[string]string {
return url.AuthConfig.AuthHeaders
}

func (url *NodeUrl) SetAuthHeaders(ctx context.Context, headerSetter func(string, string)) {
for header, headerValue := range url.AuthConfig.AuthHeaders {
headerSetter(header, headerValue)
Expand Down
2 changes: 1 addition & 1 deletion protocol/integration/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (m *mockConsumerStateTracker) RegisterForSpecUpdates(ctx context.Context, s
return nil
}

func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) {
func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus, bool) {
}

func (m *mockConsumerStateTracker) RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error {
Expand Down
1 change: 1 addition & 0 deletions protocol/lavasession/consumer_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func (csm *ConsumerSessionManager) RPCEndpoint() RPCEndpoint {
}

func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList map[uint64]*ConsumerSessionsWithProvider) error {
utils.LavaFormatDebug("UpdateAllProviders", utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "pairingListLen", Value: len(pairingList)})
pairingListLength := len(pairingList)
// TODO: we can block updating until some of the probing is done, this can prevent failed attempts on epoch change when we have no information on the providers,
// and all of them are new (less effective on big pairing lists or a process that runs for a few epochs)
Expand Down
17 changes: 9 additions & 8 deletions protocol/rpcconsumer/consumer_state_tracker_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type ConsumerStateTrackerInf interface {
RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf)
RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint)
RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error
RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus)
RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus, bool)
RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error
TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, conflictHandler common.ConflictHandlerInterface) error
GetConsumerPolicy(ctx context.Context, consumerAddress, chainID string) (*plantypes.Policy, error)
Expand Down Expand Up @@ -348,6 +348,15 @@ func (rpcc *RPCConsumer) CreateConsumerEndpoint(
return nil, err
}

// Filter the relevant static providers
relevantStaticProviderList := []*lavasession.RPCProviderEndpoint{}
for _, staticProvider := range options.staticProvidersList {
if staticProvider.ChainID == rpcEndpoint.ChainID {
relevantStaticProviderList = append(relevantStaticProviderList, staticProvider)
}
}
staticProvidersActive := len(relevantStaticProviderList) > 0

_, averageBlockTime, _, _ := chainParser.ChainBlockStats()
var optimizer *provideroptimizer.ProviderOptimizer
var consumerConsistency *ConsumerConsistency
Expand Down Expand Up @@ -387,7 +396,7 @@ func (rpcc *RPCConsumer) CreateConsumerEndpoint(
return utils.LavaFormatError("failed loading finalization consensus", err, utils.LogAttr("endpoint", rpcEndpoint.Key()))
}
if !loaded { // when creating new finalization consensus instance we need to register it to updates
consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus)
consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus, staticProvidersActive)
}
return nil
}
Expand Down
Loading

0 comments on commit 7e6d535

Please sign in to comment.