Skip to content

Commit

Permalink
Merge pull request #1912 from lavanet/PRT-add-auth-headers-to-websock…
Browse files Browse the repository at this point in the history
…et-connection

feat: PRT - Add Auth headers to Websocket connection
  • Loading branch information
ranlavanet authored Jan 16, 2025
2 parents 8c0c3ce + f8bd4a5 commit ec071f0
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 24 deletions.
4 changes: 2 additions & 2 deletions protocol/chainlib/chain_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,8 @@ func TestMain(m *testing.M) {
listener := createRPCServer()
for {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := rpcclient.DialContext(ctx, listenerAddressHttp)
_, err2 := rpcclient.DialContext(ctx, listenerAddressWs)
_, err := rpcclient.DialContext(ctx, listenerAddressHttp, nil)
_, err2 := rpcclient.DialContext(ctx, listenerAddressWs, nil)
if err2 != nil {
utils.LavaFormatDebug("waiting for grpc server to launch")
continue
Expand Down
27 changes: 16 additions & 11 deletions protocol/chainlib/chainproxy/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ func (connector *Connector) numberOfUsedClients() int {
return int(atomic.LoadInt64(&connector.usedClients))
}

func (connector *Connector) getRpcClient(ctx context.Context, nodeUrl common.NodeUrl) (*rpcclient.Client, error) {
authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url)
// origin used for auth header in the websocket case
authHeaders := nodeUrl.GetAuthHeaders()
rpcClient, err := rpcclient.DialContext(ctx, authPathNodeUrl, authHeaders)
if err != nil {
return nil, err
}
nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader)
return rpcClient, nil
}

func (connector *Connector) createConnection(ctx context.Context, nodeUrl common.NodeUrl, currentNumberOfConnections int) (*rpcclient.Client, error) {
var rpcClient *rpcclient.Client
var err error
Expand All @@ -124,21 +136,13 @@ func (connector *Connector) createConnection(ctx context.Context, nodeUrl common
}
timeout := common.AverageWorldLatency * (1 + time.Duration(numberOfConnectionAttempts))
nctx, cancel := nodeUrl.LowerContextTimeoutWithDuration(ctx, timeout)
// add auth path
authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url)
rpcClient, err = rpcclient.DialContext(nctx, authPathNodeUrl)
// get rpcClient
rpcClient, err = connector.getRpcClient(nctx, nodeUrl)
if err != nil {
utils.LavaFormatWarning("Could not connect to the node, retrying", err, []utils.Attribute{
{Key: "Current Number Of Connections", Value: currentNumberOfConnections},
{Key: "Network Address", Value: authPathNodeUrl},
{Key: "Number Of Attempts", Value: numberOfConnectionAttempts},
{Key: "timeout", Value: timeout},
}...)
cancel()
continue
}
cancel()
nodeUrl.SetAuthHeaders(ctx, rpcClient.SetHeader)
break
}

Expand Down Expand Up @@ -178,7 +182,8 @@ func (connector *Connector) increaseNumberOfClients(ctx context.Context, numberO
var err error
for connectionAttempt := 0; connectionAttempt < MaximumNumberOfParallelConnectionsAttempts; connectionAttempt++ {
nctx, cancel := connector.nodeUrl.LowerContextTimeoutWithDuration(ctx, common.AverageWorldLatency*2)
rpcClient, err = rpcclient.DialContext(nctx, connector.nodeUrl.Url)
// get rpcClient
rpcClient, err = connector.getRpcClient(nctx, connector.nodeUrl)
if err != nil {
utils.LavaFormatDebug(
"could no increase number of connections to the node jsonrpc connector, retrying",
Expand Down
90 changes: 89 additions & 1 deletion protocol/chainlib/chainproxy/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chainproxy

import (
"context"
"encoding/json"
"fmt"
"log"
"net"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/lavanet/lava/v4/utils"
pb_pkg "github.com/lavanet/lava/v4/x/spec/types"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
Expand Down Expand Up @@ -185,7 +187,7 @@ func TestMain(m *testing.M) {
listener := createRPCServer()
for {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := rpcclient.DialContext(ctx, listenerAddressTcp)
_, err := rpcclient.DialContext(ctx, listenerAddressTcp, nil)
if err != nil {
utils.LavaFormatDebug("waiting for grpc server to launch")
continue
Expand All @@ -199,3 +201,89 @@ func TestMain(m *testing.M) {
listener.Close()
os.Exit(code)
}

func TestConnectorWebsocket(t *testing.T) {
// Set up auth headers we expect
expectedAuthHeader := "Bearer test-token"

// Create WebSocket server with auth check
srv := &http.Server{
Addr: "localhost:0", // random available port
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check auth header
authHeader := r.Header.Get("Authorization")
if authHeader != expectedAuthHeader {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
fmt.Println("connection OK!")
// Upgrade to websocket
upgrader := websocket.Server{
Handler: websocket.Handler(func(ws *websocket.Conn) {
defer ws.Close()
// Simple echo server
for {
var msg string
err := websocket.Message.Receive(ws, &msg)
if err != nil {
break
}
websocket.Message.Send(ws, msg)
}
}),
}
upgrader.ServeHTTP(w, r)
}),
}

// Start server
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer listener.Close()

go srv.Serve(listener)
wsURL := "ws://" + listener.Addr().String()

// Create connector with auth config
ctx := context.Background()
nodeUrl := common.NodeUrl{
Url: wsURL,
AuthConfig: common.AuthConfig{
AuthHeaders: map[string]string{
"Authorization": expectedAuthHeader,
},
},
}

// Create connector
conn, err := NewConnector(ctx, numberOfClients, nodeUrl)
require.NoError(t, err)
defer conn.Close()

// Wait for connections to be established
for {
if len(conn.freeClients) == numberOfClients {
break
}
time.Sleep(10 * time.Millisecond)
}

// Get a client and test the connection
client, err := conn.GetRpc(ctx, true)
require.NoError(t, err)

// Test sending a message using CallContext
params := map[string]interface{}{
"test": "value",
}
id := json.RawMessage(`1`)
_, err = client.CallContext(ctx, id, "test_method", params, true, true)
require.NoError(t, err)

// Return the client
conn.ReturnRpc(client)

// Verify connection pool state
require.Equal(t, int64(0), conn.usedClients)
require.Equal(t, numberOfClients, len(conn.freeClients))
}
6 changes: 3 additions & 3 deletions protocol/chainlib/chainproxy/rpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,14 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*JsonrpcMessage, erro
//
// The client reconnects automatically if the connection is lost.
func Dial(rawurl string) (*Client, error) {
return DialContext(context.Background(), rawurl)
return DialContext(context.Background(), rawurl, nil)
}

// DialContext creates a new RPC client, just like Dial.
//
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
func DialContext(ctx context.Context, rawurl string, wsHeaders map[string]string) (*Client, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
Expand All @@ -193,7 +193,7 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) {
case "http", "https":
return DialHTTP(rawurl)
case "ws", "wss":
return DialWebsocket(ctx, rawurl, "")
return DialWebsocket(ctx, rawurl, wsHeaders)
case "stdio":
return DialStdIO(ctx)
case "":
Expand Down
14 changes: 7 additions & 7 deletions protocol/chainlib/chainproxy/rpcclient/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ func parseOriginURL(origin string) (string, string, string, error) {

// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
// that is listening on the given endpoint using the provided dialer.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, origin)
func DialWebsocketWithDialer(ctx context.Context, endpoint string, dialer websocket.Dialer, headers map[string]string) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, headers)
if err != nil {
return nil, err
}
Expand All @@ -210,23 +210,23 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
//
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
func DialWebsocket(ctx context.Context, endpoint string, headers map[string]string) (*Client, error) {
dialer := websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
return DialWebsocketWithDialer(ctx, endpoint, dialer, headers)
}

func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
func wsClientHeaders(endpoint string, headers map[string]string) (string, http.Header, error) {
endpointURL, err := url.Parse(endpoint)
if err != nil {
return endpoint, nil, err
}
header := make(http.Header)
if origin != "" {
header.Add("origin", origin)
for headerKey, headerValue := range headers {
header.Add(headerKey, headerValue)
}
if endpointURL.User != nil {
b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
Expand Down
4 changes: 4 additions & 0 deletions protocol/common/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func (nurl *NodeUrl) UrlStr() string {
return parsedURL.String()
}

func (url *NodeUrl) GetAuthHeaders() map[string]string {
return url.AuthConfig.AuthHeaders
}

func (url *NodeUrl) SetAuthHeaders(ctx context.Context, headerSetter func(string, string)) {
for header, headerValue := range url.AuthConfig.AuthHeaders {
headerSetter(header, headerValue)
Expand Down

0 comments on commit ec071f0

Please sign in to comment.