Skip to content

Commit

Permalink
Added support for Websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobm-splunk authored and daveshanley committed Apr 5, 2024
1 parent a7d6a8c commit 5e9102a
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 73 deletions.
15 changes: 15 additions & 0 deletions cmd/handle_http_traffic.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da
wtService.HandleHttpRequest(requestModel)
}

handleWebsocket := func(w http.ResponseWriter, r *http.Request) {
id, _ := uuid.NewUUID()
requestModel := &model.Request{
Id: &id,
HttpRequest: r,
HttpResponseWriter: w,
}
wtService.HandleWebsocketRequest(requestModel)
}

// create a new mux.
mux := http.NewServeMux()

// handle the index
mux.HandleFunc("/", handleTraffic)

// Handle Websockets
for websocket := range wiretapConfig.WebsocketConfigs {
mux.HandleFunc(websocket, handleWebsocket)
}

pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port)))

var httpErr error
Expand Down
24 changes: 21 additions & 3 deletions cmd/root_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ var (
printLoadedRedirectAllowList(config.RedirectAllowList)
}

if len(config.WebsocketConfigs) > 0 {
for _, config := range config.WebsocketConfigs {
if config.VerifyCert == nil {
config.VerifyCert = func() *bool { b := true; return &b }()
}
}

printLoadedWebsockets(config.WebsocketConfigs)
}

// static headers
if config.Headers != nil && len(config.Headers.DropHeaders) > 0 {
pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders),
Expand Down Expand Up @@ -625,8 +635,7 @@ func Execute(version, commit, date string, fs embed.FS) {
rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag")
rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag")
rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)")
rootCmd.Flags().StringP("config", "c", "",
"Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)")
rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from")
rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging")
rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic")
Expand Down Expand Up @@ -706,11 +715,20 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) {
}

func printLoadedRedirectAllowList(allowRedirects []string) {
pterm.Info.Printf("Loaded %d allows listed redirect %s :\n", len(allowRedirects),
pterm.Info.Printf("Loaded %d allows listed redirect %s:\n", len(allowRedirects),
shared.Pluralize(len(allowRedirects), "path", "paths"))

for _, x := range allowRedirects {
pterm.Printf("🐵 Paths matching '%s' will always follow redirects, regardless of ignoreRedirect settings\n", pterm.LightCyan(x))
}
pterm.Println()
}

func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) {
pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets"))

for websocket := range websockets {
pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket))
}
pterm.Println()
}
231 changes: 205 additions & 26 deletions daemon/handle_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package daemon

import (
"crypto/tls"
_ "embed"
"fmt"
"github.com/gorilla/websocket"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
}
}

var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

newReq := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Expand Down Expand Up @@ -238,8 +215,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
_, _ = request.HttpResponseWriter.Write(body)
}

var gorillaDropHeaders = []string{
// Gorilla fills in the following headers, and complains if they are already present
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Protocol",
"Sec-Websocket-Extensions",
}

func (ws *WiretapService) handleWebsocketRequest(request *model.Request) {

configStore, _ := ws.controlsStore.Get(shared.ConfigKey)
config := configStore.(*shared.WiretapConfiguration)

// Get the Websocket Configuration
websocketUrl := request.HttpRequest.URL.String()
websocketConfig, ok := config.WebsocketConfigs[websocketUrl]
if !ok {
ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl))
}

// There's nothing to do if we're in mock mode
if config.MockMode {
return
}

upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// Upgrade the connection from the client to open a websocket connection
clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil)
if err != nil {
ws.config.Logger.Error("Unable to upgrade websocket connection")
return
}
defer func(clientConn *websocket.Conn) {
_ = clientConn.Close()
}(clientConn)

if config.Headers == nil || len(config.Headers.DropHeaders) == 0 {
config.Headers = &shared.WiretapHeaderConfig{
DropHeaders: []string{},
}
}

// Get the updated headers and auth
dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request)

dropHeaders = append(dropHeaders, gorillaDropHeaders...)
dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...)

// Determine the correct websocket protocol based on redirect protocol
var protocol string
if config.RedirectProtocol == "https" {
protocol = "wss"
} else if config.RedirectProtocol == "http" {
protocol = "ws"
} else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" {
config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol))
return
}

// Create a new request, which fills in the URL and other information
newRequest := CloneExistingRequest(CloneRequest{
Request: request.HttpRequest,
Protocol: protocol,
Host: config.RedirectHost,
BasePath: config.RedirectBasePath,
Port: config.RedirectPort,
DropHeaders: dropHeaders,
InjectHeaders: injectHeaders,
Auth: auth,
Variables: config.CompiledVariables,
})

// Open a new websocket connection with the server
dialer := *websocket.DefaultDialer
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert}
serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header)
if err != nil {
ws.config.Logger.Error("Unable to create server connection")
return
}
defer func(serverConn *websocket.Conn) {
_ = serverConn.Close()
}(serverConn)

// Create sentinel channels
clientSentinel := make(chan struct{})
serverSentinel := make(chan struct{})

// Go-Routine for communication between Client -> Server
go func() {
defer close(clientSentinel)

for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = serverConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Go-Routine for communication between Server -> Client
go func() {
defer close(serverSentinel)

for {
messageType, message, err := serverConn.ReadMessage()
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
_ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
return
}

err = clientConn.WriteMessage(messageType, message)
if err != nil {
closeCode, isUnexpected := getCloseCode(err)
logWebsocketClose(config, closeCode, isUnexpected)
return
}
}
}()

// Loop until at least one of our sentinel channels have been closed
for {
select {
case <-clientSentinel:
return
case <-serverSentinel:
return
}
}
}

func setCORSHeaders(headers map[string][]string) {
headers["Access-Control-Allow-Headers"] = []string{"*"}
headers["Access-Control-Allow-Origin"] = []string{"*"}
headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"}
}

func getCloseCode(err error) (int, bool) {
unexpectedClose := websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
)

if ce, ok := err.(*websocket.CloseError); ok {
return ce.Code, unexpectedClose
}
return -1, unexpectedClose
}

func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) {
if isUnexpected {
config.Logger.Error(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode))
} else {
config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode))
}
}

func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) {
var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}

return dropHeaders, injectHeaders, auth
}
5 changes: 4 additions & 1 deletion daemon/wiretap_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv
}

func (ws *WiretapService) HandleHttpRequest(request *model.Request) {

ws.handleHttpRequest(request)
}

func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) {
ws.handleWebsocketRequest(request)
}
Loading

0 comments on commit 5e9102a

Please sign in to comment.