diff --git a/daemon/api.go b/daemon/api.go index 91db365..9a3a399 100644 --- a/daemon/api.go +++ b/daemon/api.go @@ -4,9 +4,11 @@ package daemon import ( - "fmt" - "github.com/pb33f/wiretap/shared" - "net/http" + "crypto/tls" + "fmt" + "net/http" + + "github.com/pb33f/wiretap/shared" ) type wiretapTransport struct { @@ -15,6 +17,8 @@ type wiretapTransport struct { } func newWiretapTransport() *wiretapTransport { + // Disable ssl cert checks + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} return &wiretapTransport{ originalTransport: http.DefaultTransport, } @@ -39,13 +43,21 @@ func (ws *WiretapService) callAPI(req *http.Request, responseChan chan *http.Res configStore, _ := ws.controlsStore.Get(shared.ConfigKey) // create a new request from the original request, but replace the path - config := configStore.(*shared.WiretapConfiguration) - - resp, err := client.Do(cloneRequest(req, + newReq := cloneRequest(req, config.RedirectProtocol, config.RedirectHost, - config.RedirectPort)) + config.RedirectPort) + // re-write referer + if (newReq.Header.Get("Referer") != "") { + // retain original referer for logging + newReq.Header.Set("X-Original-Referer", newReq.Header.Get("Referer")) + newReq.Header.Set("Referer", reconstructURL(req, + config.RedirectProtocol, + config.RedirectHost, + config.RedirectPort)) + } + resp, err := client.Do(newReq) if err != nil { errorChan <- err diff --git a/daemon/wiretap_utils.go b/daemon/wiretap_utils.go index dd5fd31..8657677 100644 --- a/daemon/wiretap_utils.go +++ b/daemon/wiretap_utils.go @@ -18,6 +18,30 @@ func extractHeaders(resp *http.Response) map[string]any { return headers } +func reconstructURL(r *http.Request, protocol, host, port string) string { + url := fmt.Sprintf("%s://%s", protocol, host) + // pattern := "%s://%s:%s/%s?%s" + // urlString := fmt.Sprintf(pattern, protocol, host, port, r.URL.Path, r.URL.RawQuery) + if port != "" { + url += fmt.Sprintf(":%s", port) + // pattern = "%s://%s/%s?%s" + // urlString = fmt.Sprintf(pattern, protocol, host, r.URL.Path, r.URL.RawQuery) + } + if r.URL.Path != "" { + url += r.URL.Path + } + if r.URL.RawQuery != "" { + url += fmt.Sprintf("?%s", r.URL.RawQuery) + // pattern = "%s://%s:%s/%s" + // urlString = fmt.Sprintf(pattern, protocol, host, port, r.URL.Path) + // if port == "" { + // pattern = "%s://%s/%s" + // urlString = fmt.Sprintf(pattern, protocol, host, r.URL.Path) + // } + } + return url +} + func cloneRequest(r *http.Request, protocol, host, port string) *http.Request { // todo: replace with config/server etc. // todo: check query params @@ -27,23 +51,9 @@ func cloneRequest(r *http.Request, protocol, host, port string) *http.Request { _ = r.Body.Close() r.Body = io.NopCloser(bytes.NewBuffer(b)) - pattern := "%s://%s:%s?%s" - urlString := fmt.Sprintf(pattern, protocol, host, port, r.URL.RawQuery) - if port == "" { - pattern = "%s://%s?%s" - urlString = fmt.Sprintf(pattern, protocol, host, r.URL.RawQuery) - } - if r.URL.RawQuery == "" { - pattern = "%s://%s:%s" - urlString = fmt.Sprintf(pattern, protocol, host, port) - if port == "" { - pattern = "%s://%s" - urlString = fmt.Sprintf(pattern, protocol, host) - } - } - - newBaseURL := urlString - newReq, _ := http.NewRequest(r.Method, newBaseURL, io.NopCloser(bytes.NewBuffer(b))) + // create cloned request + newURL := reconstructURL(r, protocol, host, port) + newReq, _ := http.NewRequest(r.Method, newURL, io.NopCloser(bytes.NewBuffer(b))) newReq.Header = r.Header return newReq } diff --git a/daemon/wiretap_utils_test.go b/daemon/wiretap_utils_test.go new file mode 100644 index 0000000..d62642c --- /dev/null +++ b/daemon/wiretap_utils_test.go @@ -0,0 +1,37 @@ +// Copyright 2023 Princess B33f Heavy Industries / Dave Shanley +// SPDX-License-Identifier: MIT + +package daemon + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReconstructURL(t *testing.T){ + protocol := "http" + host := "localhost" + port := "8000" + // Making sure trailing slashes are accounted for correctly + r, _ := http.NewRequest("GET", "http://localhost:1337/", nil) + assert.Equal(t, "http://localhost:8000/", reconstructURL(r, protocol, host, port)) + r, _ = http.NewRequest("GET", "http://localhost:1337", nil) + assert.Equal(t, "http://localhost:8000", reconstructURL(r, protocol, host, port)) + // Adding port correctly + r, _ = http.NewRequest("GET", "http://localhost/", nil) + assert.Equal(t, "http://localhost/", reconstructURL(r, protocol, host, "")) + r, _ = http.NewRequest("GET", "http://localhost:8000", nil) + assert.Equal(t, "http://localhost", reconstructURL(r, protocol, host, "")) + // Adding paths correctly + r, _ = http.NewRequest("POST", "http://localhost/dalek", nil) + assert.Equal(t, "http://localhost:8000/dalek", reconstructURL(r, protocol, host, port)) + r, _ = http.NewRequest("PUT", "http://localhost/dalek/1337", nil) + assert.Equal(t, "http://localhost/dalek/1337", reconstructURL(r, protocol, host, "")) + // Adding query params correctly + r, _ = http.NewRequest("GET", "http://localhost?doctor=who", nil) + assert.Equal(t, "http://localhost:8000?doctor=who", reconstructURL(r, protocol, host, port)) + r, _ = http.NewRequest("GET", "http://localhost:1337?doctor=who", nil) + assert.Equal(t, "http://localhost?doctor=who", reconstructURL(r, protocol, host, "")) +} \ No newline at end of file