Skip to content

Commit

Permalink
Add credential to OpenDataChannel request
Browse files Browse the repository at this point in the history
  • Loading branch information
Yangtao-Hua committed Nov 6, 2024
1 parent 7b5e24f commit 4c4508b
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 34 deletions.
46 changes: 41 additions & 5 deletions src/communicator/mocks/IWebSocketChannel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 64 additions & 3 deletions src/communicator/websocketchannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ package communicator

import (
"errors"
"net/http"
"net/url"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/session-manager-plugin/src/config"
"github.com/aws/session-manager-plugin/src/log"
"github.com/aws/session-manager-plugin/src/websocketutil"
Expand All @@ -27,7 +30,7 @@ import (

// IWebSocketChannel is the interface for DataChannel.
type IWebSocketChannel interface {
Initialize(log log.T, channelUrl string, channelToken string)
Initialize(log log.T, channelUrl string, channelToken string, region string, signer *v4.Signer)
Open(log log.T) error
Close(log log.T) error
SendMessage(log log.T, input []byte, inputType int) error
Expand All @@ -49,6 +52,8 @@ type WebSocketChannel struct {
writeLock *sync.Mutex
Connection *websocket.Conn
ChannelToken string
Region string
Signer *v4.Signer
}

// GetChannelToken gets the channel token
Expand Down Expand Up @@ -77,9 +82,11 @@ func (webSocketChannel *WebSocketChannel) SetOnMessage(onMessageHandler func([]b
}

// Initialize initializes websocket channel fields
func (webSocketChannel *WebSocketChannel) Initialize(log log.T, channelUrl string, channelToken string) {
func (webSocketChannel *WebSocketChannel) Initialize(log log.T, channelUrl string, channelToken string, region string, signer *v4.Signer) {
webSocketChannel.ChannelToken = channelToken
webSocketChannel.Url = channelUrl
webSocketChannel.Region = region
webSocketChannel.Signer = signer
}

// StartPings starts the pinging process to keep the websocket channel alive.
Expand Down Expand Up @@ -121,6 +128,47 @@ func (webSocketChannel *WebSocketChannel) SendMessage(log log.T, input []byte, i
return err
}

// getV4SignatureHeader gets the signed header.
func (webSocketChannel *WebSocketChannel) getV4SignatureHeader(log log.T, Url string) (http.Header, error) {
request, err := http.NewRequest("GET", Url, nil)

if webSocketChannel.Signer != nil {
_, err = webSocketChannel.Signer.Sign(request, nil, config.ServiceName, webSocketChannel.Region, time.Now())
if err != nil {
log.Errorf("Failed to sign websocket, %v", err)
}
}
return request.Header, err
}

// isPresignedURL check is the url presigned.
func isPresignedURL(rawURL string) (bool, error) {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return false, err
}

queryParams := parsedURL.Query()

presignedURLParams := []string{
"X-Amz-Algorithm",
"X-Amz-Credential",
"X-Amz-Date",
"X-Amz-Expires",
"X-Amz-SignedHeaders",
"X-Amz-Signature",
"X-Amz-Security-Token",
}

for _, param := range presignedURLParams {
if _, exists := queryParams[param]; exists {
return true, nil
}
}

return false, nil
}

// Close closes the corresponding connection.
func (webSocketChannel *WebSocketChannel) Close(log log.T) error {

Expand All @@ -139,9 +187,22 @@ func (webSocketChannel *WebSocketChannel) Close(log log.T) error {
func (webSocketChannel *WebSocketChannel) Open(log log.T) error {
// initialize the write mutex
webSocketChannel.writeLock = &sync.Mutex{}
presigned, err := isPresignedURL(webSocketChannel.Url)
if err != nil {
return err
}

var header http.Header
if !presigned {
header, err = webSocketChannel.getV4SignatureHeader(log, webSocketChannel.Url)
if err != nil {
log.Errorf("Failed to get the v4 signature, %v", err)
}
}

ws, err := websocketutil.NewWebsocketUtil(log, nil).OpenConnection(webSocketChannel.Url)
ws, err := websocketutil.NewWebsocketUtil(log, nil).OpenConnection(webSocketChannel.Url, header)
if err != nil {
log.Errorf("Failed to open WebSocket connection: %v", err)
return err
}
webSocketChannel.Connection = ws
Expand Down
50 changes: 49 additions & 1 deletion src/communicator/websocketchannel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/session-manager-plugin/src/log"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
Expand All @@ -35,6 +38,8 @@ var (
defaultStreamUrl = "streamUrl"
defaultError = errors.New("Default Error")
defaultMessage = []byte("Default Message")
defaultRegion = "us-east-1"
mockSigner = &v4.Signer{Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")}
)

type ErrorCallbackWrapper struct {
Expand Down Expand Up @@ -141,10 +146,11 @@ func TestWebsocketChannel_SetOnMessage(t *testing.T) {
func TestWebsocketchannel_Initialize(t *testing.T) {
t.Log("Starting test: webSocketChannel.Initialize")
channel := &WebSocketChannel{}
channel.Initialize(mockLogger, defaultStreamUrl, defaultChannelToken)
channel.Initialize(mockLogger, defaultStreamUrl, defaultChannelToken, defaultRegion, mockSigner)

assert.Equal(t, defaultStreamUrl, channel.Url)
assert.Equal(t, defaultChannelToken, channel.ChannelToken)
assert.Equal(t, mockSigner, channel.Signer)
}

func TestOpenCloseWebSocketChannel(t *testing.T) {
Expand All @@ -169,6 +175,48 @@ func TestOpenCloseWebSocketChannel(t *testing.T) {
t.Log("Ending test: TestOpenCloseWebSocketChannel")
}

func TestOpenWebSocketChannelWithPresignedURL(t *testing.T) {
t.Log("Starting test: TestOpenWebSocketChannelWithPresignedURL")
srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested))
u, _ := url.Parse(srv.URL)
u.Scheme = "ws"
var log = log.NewMockLog()

query := u.Query()
query.Set("X-Amz-Signature", "SAMPLE_SIGNATURE")
u.RawQuery = query.Encode()

websocketchannel := WebSocketChannel{
Url: u.String(),
Signer: nil,
}

err := websocketchannel.Open(log)
assert.Nil(t, err, "Error opening the websocket connection.")
assert.NotNil(t, websocketchannel.Connection, "Open connection failed.")
assert.True(t, websocketchannel.IsOpen, "IsOpen is not set to true.")
assert.True(t, strings.Contains(websocketchannel.Url, "SAMPLE_SIGNATURE"),
"URL not included signature as expected")

err = websocketchannel.Close(log)
assert.Nil(t, err, "Error closing the websocket connection.")
assert.False(t, websocketchannel.IsOpen, "IsOpen is not set to false.")
t.Log("Ending test: TestOpenCloseWebSocketChannel")
}

func TestOpenWebSocketChannelWithInvalidURL(t *testing.T) {
t.Log("Starting test: TestOpenWebSocketChannelWithInvalidURL")
var log = log.NewMockLog()
websocketchannel := WebSocketChannel{
Url: "invalid_url",
Signer: nil,
}

err := websocketchannel.Open(log)
assert.NotNil(t, err, "malformed ws or wss URL.")
assert.Nil(t, websocketchannel.Connection, "Open connection failed.")
}

func TestReadWriteTextToWebSocketChannel(t *testing.T) {
t.Log("Starting test: TestReadWriteWebSocketChannel ")
srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested))
Expand Down
1 change: 1 addition & 0 deletions src/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package config
import "time"

const (
ServiceName = "ssmmessages"
RolePublishSubscribe = "publish_subscribe"
MessageSchemaVersion = "1.0"
DefaultTransmissionTimeout = 200 * time.Millisecond
Expand Down
Loading

0 comments on commit 4c4508b

Please sign in to comment.