From a9b6ba287b5f3f3d75d6896664a8eceda27551f6 Mon Sep 17 00:00:00 2001 From: Craig Shannon Date: Thu, 13 Jul 2023 08:21:35 -0500 Subject: [PATCH] Do not return a denyError for DNS resolution failures (#194) * dont return denial errors for dns resolution failures * fix test * move DNSError check into net.Error assertion, extend test * fix integration test --- cmd/integration_test.go | 2 +- pkg/smokescreen/smokescreen.go | 11 ++++++++++- pkg/smokescreen/smokescreen_test.go | 16 +++++++++++----- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/cmd/integration_test.go b/cmd/integration_test.go index 96771e96..ed1320cb 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -441,7 +441,7 @@ func validateProxyResponseWithUpstream(t *testing.T, test *TestCase, resp *http. t.Logf("HTTP Response: %#v", resp) if test.OverConnect { - a.Contains(err.Error(), "Failed to connect to remote host") + a.Contains(err.Error(), "Failed to resolve remote hostname") } else { a.Equal(http.StatusBadGateway, resp.StatusCode) } diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 2c3c61d4..69e9dd38 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -356,6 +356,11 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response { status = "Gateway timeout" code = http.StatusGatewayTimeout msg = "Timed out connecting to remote host: " + e.Error() + + } else if e, ok := err.(*net.DNSError); ok { + status = "Bad gateway" + code = http.StatusBadGateway + msg = "Failed to resolve remote hostname: " + e.Error() } else { status = "Bad gateway" code = http.StatusBadGateway @@ -620,9 +625,13 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { pctx.Error = denyError{err} return "", pctx.Error } + + // checkIfRequestShouldBeProxied can return an error if either the resolved address is disallowed, + // or if there is a DNS resolution failure. sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) if pctx.Error != nil { - return "", denyError{pctx.Error} + // DNS resolution failure + return "", pctx.Error } if !sctx.decision.allow { return "", denyError{errors.New(sctx.decision.reason)} diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 4e653407..4e01ef9a 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -406,11 +406,10 @@ func TestHealthcheck(t *testing.T) { var invalidHostCases = []struct { scheme string - expectErr bool proxyType string }{ - {"http", false, "http"}, - {"https", true, "connect"}, + {"http", "http"}, + {"https", "connect"}, } func TestInvalidHost(t *testing.T) { @@ -430,12 +429,19 @@ func TestInvalidHost(t *testing.T) { client, err := proxyClient(proxySrv.URL) r.NoError(err) + // This hostname does not exist and should never resolve resp, err := client.Get(fmt.Sprintf("%s://notarealhost.test", testCase.scheme)) - if testCase.expectErr { - r.Contains(err.Error(), "Request rejected by proxy") + if testCase.scheme == "https" { + r.Error(err) + r.Contains(err.Error(), "Bad gateway") } else { + // Plain HTTP r.NoError(err) r.Equal(http.StatusBadGateway, resp.StatusCode) + + defer resp.Body.Close() + b, _ := ioutil.ReadAll(resp.Body) + r.Contains(string(b), "Failed to resolve remote hostname") } entry := findCanonicalProxyDecision(logHook.AllEntries())