diff --git a/protocol/chainlib/chain_router_test.go b/protocol/chainlib/chain_router_test.go index db1b9fa8d6..e2fe5e1c0a 100644 --- a/protocol/chainlib/chain_router_test.go +++ b/protocol/chainlib/chain_router_test.go @@ -1175,8 +1175,8 @@ func TestMain(m *testing.M) { listener := createRPCServer() for { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := rpcclient.DialContext(ctx, listenerAddressHttp) - _, err2 := rpcclient.DialContext(ctx, listenerAddressWs) + _, err := rpcclient.DialContext(ctx, listenerAddressHttp, nil) + _, err2 := rpcclient.DialContext(ctx, listenerAddressWs, nil) if err2 != nil { utils.LavaFormatDebug("waiting for grpc server to launch") continue diff --git a/protocol/chainlib/chainproxy/connector.go b/protocol/chainlib/chainproxy/connector.go index 307f86915a..f5c7d51ff7 100644 --- a/protocol/chainlib/chainproxy/connector.go +++ b/protocol/chainlib/chainproxy/connector.go @@ -107,6 +107,18 @@ func (connector *Connector) numberOfUsedClients() int { return int(atomic.LoadInt64(&connector.usedClients)) } +func (connector *Connector) getRpcClient(ctx context.Context, nodeUrl common.NodeUrl) (*rpcclient.Client, error) { + authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url) + // origin used for auth header in the websocket case + authHeaders := nodeUrl.GetAuthHeaders() + rpcClient, err := rpcclient.DialContext(ctx, authPathNodeUrl, authHeaders) + if err != nil { + return nil, err + } + nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader) + return rpcClient, nil +} + func (connector *Connector) createConnection(ctx context.Context, nodeUrl common.NodeUrl, currentNumberOfConnections int) (*rpcclient.Client, error) { var rpcClient *rpcclient.Client var err error @@ -124,21 +136,13 @@ func (connector *Connector) createConnection(ctx context.Context, nodeUrl common } timeout := common.AverageWorldLatency * (1 + time.Duration(numberOfConnectionAttempts)) nctx, cancel := nodeUrl.LowerContextTimeoutWithDuration(ctx, timeout) - // add auth path - authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url) - rpcClient, err = rpcclient.DialContext(nctx, authPathNodeUrl) + // get rpcClient + rpcClient, err = connector.getRpcClient(nctx, nodeUrl) if err != nil { - utils.LavaFormatWarning("Could not connect to the node, retrying", err, []utils.Attribute{ - {Key: "Current Number Of Connections", Value: currentNumberOfConnections}, - {Key: "Network Address", Value: authPathNodeUrl}, - {Key: "Number Of Attempts", Value: numberOfConnectionAttempts}, - {Key: "timeout", Value: timeout}, - }...) cancel() continue } cancel() - nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader) break } @@ -178,7 +182,8 @@ func (connector *Connector) increaseNumberOfClients(ctx context.Context, numberO var err error for connectionAttempt := 0; connectionAttempt < MaximumNumberOfParallelConnectionsAttempts; connectionAttempt++ { nctx, cancel := connector.nodeUrl.LowerContextTimeoutWithDuration(ctx, common.AverageWorldLatency*2) - rpcClient, err = rpcclient.DialContext(nctx, connector.nodeUrl.Url) + // get rpcClient + rpcClient, err = connector.getRpcClient(nctx, connector.nodeUrl) if err != nil { utils.LavaFormatDebug( "could no increase number of connections to the node jsonrpc connector, retrying", diff --git a/protocol/chainlib/chainproxy/connector_test.go b/protocol/chainlib/chainproxy/connector_test.go index e408ae2062..26a8100425 100644 --- a/protocol/chainlib/chainproxy/connector_test.go +++ b/protocol/chainlib/chainproxy/connector_test.go @@ -2,6 +2,7 @@ package chainproxy import ( "context" + "encoding/json" "fmt" "log" "net" @@ -16,6 +17,7 @@ import ( "github.com/lavanet/lava/v4/utils" pb_pkg "github.com/lavanet/lava/v4/x/spec/types" "github.com/stretchr/testify/require" + "golang.org/x/net/websocket" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -185,7 +187,7 @@ func TestMain(m *testing.M) { listener := createRPCServer() for { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := rpcclient.DialContext(ctx, listenerAddressTcp) + _, err := rpcclient.DialContext(ctx, listenerAddressTcp, nil) if err != nil { utils.LavaFormatDebug("waiting for grpc server to launch") continue @@ -199,3 +201,89 @@ func TestMain(m *testing.M) { listener.Close() os.Exit(code) } + +func TestConnectorWebsocket(t *testing.T) { + // Set up auth headers we expect + expectedAuthHeader := "Bearer test-token" + + // Create WebSocket server with auth check + srv := &http.Server{ + Addr: "localhost:0", // random available port + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check auth header + authHeader := r.Header.Get("Authorization") + if authHeader != expectedAuthHeader { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + fmt.Println("connection OK!") + // Upgrade to websocket + upgrader := websocket.Server{ + Handler: websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + // Simple echo server + for { + var msg string + err := websocket.Message.Receive(ws, &msg) + if err != nil { + break + } + websocket.Message.Send(ws, msg) + } + }), + } + upgrader.ServeHTTP(w, r) + }), + } + + // Start server + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer listener.Close() + + go srv.Serve(listener) + wsURL := "ws://" + listener.Addr().String() + + // Create connector with auth config + ctx := context.Background() + nodeUrl := common.NodeUrl{ + Url: wsURL, + AuthConfig: common.AuthConfig{ + AuthHeaders: map[string]string{ + "Authorization": expectedAuthHeader, + }, + }, + } + + // Create connector + conn, err := NewConnector(ctx, numberOfClients, nodeUrl) + require.NoError(t, err) + defer conn.Close() + + // Wait for connections to be established + for { + if len(conn.freeClients) == numberOfClients { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Get a client and test the connection + client, err := conn.GetRpc(ctx, true) + require.NoError(t, err) + + // Test sending a message using CallContext + params := map[string]interface{}{ + "test": "value", + } + id := json.RawMessage(`1`) + _, err = client.CallContext(ctx, id, "test_method", params, true, true) + require.NoError(t, err) + + // Return the client + conn.ReturnRpc(client) + + // Verify connection pool state + require.Equal(t, int64(0), conn.usedClients) + require.Equal(t, numberOfClients, len(conn.freeClients)) +} diff --git a/protocol/chainlib/chainproxy/rpcclient/client.go b/protocol/chainlib/chainproxy/rpcclient/client.go index c745f82ee2..15fcdb9b73 100644 --- a/protocol/chainlib/chainproxy/rpcclient/client.go +++ b/protocol/chainlib/chainproxy/rpcclient/client.go @@ -177,14 +177,14 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*JsonrpcMessage, erro // // The client reconnects automatically if the connection is lost. func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) + return DialContext(context.Background(), rawurl, nil) } // DialContext creates a new RPC client, just like Dial. // // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. -func DialContext(ctx context.Context, rawurl string) (*Client, error) { +func DialContext(ctx context.Context, rawurl string, wsHeaders map[string]string) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err @@ -193,7 +193,7 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) { case "http", "https": return DialHTTP(rawurl) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + return DialWebsocket(ctx, rawurl, wsHeaders) case "stdio": return DialStdIO(ctx) case "": diff --git a/protocol/chainlib/chainproxy/rpcclient/websocket.go b/protocol/chainlib/chainproxy/rpcclient/websocket.go index 81566ffe0f..680d37167a 100755 --- a/protocol/chainlib/chainproxy/rpcclient/websocket.go +++ b/protocol/chainlib/chainproxy/rpcclient/websocket.go @@ -187,8 +187,8 @@ func parseOriginURL(origin string) (string, string, string, error) { // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. -func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { - endpoint, header, err := wsClientHeaders(endpoint, origin) +func DialWebsocketWithDialer(ctx context.Context, endpoint string, dialer websocket.Dialer, headers map[string]string) (*Client, error) { + endpoint, header, err := wsClientHeaders(endpoint, headers) if err != nil { return nil, err } @@ -210,23 +210,23 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +func DialWebsocket(ctx context.Context, endpoint string, headers map[string]string) (*Client, error) { dialer := websocket.Dialer{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, WriteBufferPool: wsBufferPool, } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + return DialWebsocketWithDialer(ctx, endpoint, dialer, headers) } -func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { +func wsClientHeaders(endpoint string, headers map[string]string) (string, http.Header, error) { endpointURL, err := url.Parse(endpoint) if err != nil { return endpoint, nil, err } header := make(http.Header) - if origin != "" { - header.Add("origin", origin) + for headerKey, headerValue := range headers { + header.Add(headerKey, headerValue) } if endpointURL.User != nil { b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) diff --git a/protocol/common/endpoints.go b/protocol/common/endpoints.go index 2379512708..098b301b57 100644 --- a/protocol/common/endpoints.go +++ b/protocol/common/endpoints.go @@ -91,6 +91,10 @@ func (nurl *NodeUrl) UrlStr() string { return parsedURL.String() } +func (url *NodeUrl) GetAuthHeaders() map[string]string { + return url.AuthConfig.AuthHeaders +} + func (url *NodeUrl) SetAuthHeaders(ctx context.Context, headerSetter func(string, string)) { for header, headerValue := range url.AuthConfig.AuthHeaders { headerSetter(header, headerValue)