Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for setting basic auth in custom proxy #24

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -85,7 +86,6 @@ func (proxy *ProxyHttpServer) connectDialContext(ctx *ProxyCtx, network, addr st
}

func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request) {

ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy}
hij, ok := w.(http.Hijacker)
if !ok {
Expand Down Expand Up @@ -217,6 +217,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
}
req, resp := proxy.filterRequest(req, ctx)
if resp == nil {
if err := proxy.addBasicAuth(r, false); err != nil {
ctx.Warnf("Error adding basic auth credential to request %v", err)
}
if err := req.Write(targetSiteCon); err != nil {
httpError(proxyClient, ctx, err)
return
Expand Down Expand Up @@ -264,7 +267,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
// Set the RoundTripper on the ProxyCtx within the `HandleConnect` action of goproxy, then
// inject the roundtripper here in order to use a custom round tripper while mitm.
var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, UserData: ctx.UserData, RoundTripper: ctx.RoundTripper}
if err != nil && err != io.EOF {
if err != nil && errors.Is(err, io.EOF) {
return
}
if err != nil {
Expand All @@ -289,6 +292,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
return
}
removeProxyHeaders(ctx, req)
if err := proxy.addBasicAuth(req, true); err != nil {
ctx.Warnf("Error adding basic auth credential to request %v", err)
}
resp, err = ctx.RoundTrip(req)
if err != nil {
ctx.Warnf("Cannot read TLS response from mitm'd server %v", err)
Expand Down Expand Up @@ -441,6 +447,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin
Host: addr,
Header: make(http.Header),
}
if user := u.User; user != nil {
connectReq.Header.Set("Proxy-Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(user.String())))
}
if connectReqHandler != nil {
connectReqHandler(connectReq)
}
Expand Down Expand Up @@ -486,6 +495,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin
Host: addr,
Header: make(http.Header),
}
if user := u.User; user != nil {
connectReq.Header.Set("Proxy-Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(user.String())))
}
if connectReqHandler != nil {
connectReqHandler(connectReq)
}
Expand Down Expand Up @@ -549,6 +561,9 @@ func (proxy *ProxyHttpServer) connectDialProxyWithContext(ctx *ProxyCtx, proxyHo
Host: host,
Header: make(http.Header),
}
if err := proxy.addBasicAuth(connectReq, true); err != nil {
return nil, err
}
connectReq.Write(c)
// Read response.
// Okay to use and discard buffered reader here, because
Expand Down Expand Up @@ -605,5 +620,8 @@ func httpsProxyAddr(reqURL *url.URL, httpsProxy string) (string, error) {
service = proxyURL.Scheme
}

if proxyURL.User != nil {
return fmt.Sprintf("%s://%s@%s:%s", proxyURL.Scheme, proxyURL.User.String(), proxyURL.Hostname(), service), nil
}
return fmt.Sprintf("%s://%s:%s", proxyURL.Scheme, proxyURL.Hostname(), service), nil
}
19 changes: 10 additions & 9 deletions https_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ var proxytests = map[string]struct {
url string
expectProxy string
}{
"do not proxy without a proxy configured": {"", "", "", "https://foo.bar/baz", ""},
"proxy with a proxy configured": {"", "daproxy", "", "https://foo.bar/baz", "http://daproxy:http"},
"proxy without a scheme": {"", "daproxy", "", "//foo.bar/baz", "http://daproxy:http"},
"proxy with a proxy configured with a port": {"", "http://daproxy:123", "", "https://foo.bar/baz", "http://daproxy:123"},
"proxy with an https proxy configured": {"", "https://daproxy", "", "https://foo.bar/baz", "https://daproxy:https"},
"proxy with a non-matching no_proxy": {"other.bar", "daproxy", "", "https://foo.bar/baz", "http://daproxy:http"},
"do not proxy with a full no_proxy match": {"foo.bar", "daproxy", "", "https://foo.bar/baz", ""},
"do not proxy with a suffix no_proxy match": {".bar", "daproxy", "", "https://foo.bar/baz", ""},
"proxy with an custom https proxy": {"", "https://daproxy", "https://customproxy", "https://foo.bar/baz", "https://customproxy:https"},
"do not proxy without a proxy configured": {"", "", "", "https://foo.bar/baz", ""},
"proxy with a proxy configured": {"", "daproxy", "", "https://foo.bar/baz", "http://daproxy:http"},
"proxy without a scheme": {"", "daproxy", "", "//foo.bar/baz", "http://daproxy:http"},
"proxy with a proxy configured with a port": {"", "http://daproxy:123", "", "https://foo.bar/baz", "http://daproxy:123"},
"proxy with an https proxy configured": {"", "https://daproxy", "", "https://foo.bar/baz", "https://daproxy:https"},
"proxy with a non-matching no_proxy": {"other.bar", "daproxy", "", "https://foo.bar/baz", "http://daproxy:http"},
"do not proxy with a full no_proxy match": {"foo.bar", "daproxy", "", "https://foo.bar/baz", ""},
"do not proxy with a suffix no_proxy match": {".bar", "daproxy", "", "https://foo.bar/baz", ""},
"proxy with an custom https proxy": {"", "https://daproxy", "https://customproxy", "https://foo.bar/baz", "https://customproxy:https"},
"proxy with an custom https proxy with basic auth": {"", "https://daproxy", "https://user:password@customproxy", "https://foo.bar/baz", "https://user:password@customproxy:https"},
}

var envKeys = []string{"no_proxy", "http_proxy", "https_proxy", "NO_PROXY", "HTTP_PROXY", "HTTPS_PROXY"}
Expand Down
46 changes: 43 additions & 3 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package goproxy

import (
"bufio"
"encoding/base64"
"io"
"log"
"net"
Expand Down Expand Up @@ -102,6 +103,7 @@ func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req
}
return
}

func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *ProxyCtx) (resp *http.Response) {
resp = respOrig
for _, h := range proxy.respHandlers {
Expand All @@ -111,6 +113,37 @@ func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *Proxy
return
}

func (proxy *ProxyHttpServer) addBasicAuth(r *http.Request, forceHTTPS bool) error {
if r.Header.Get("Proxy-Authorization") != "" {
return nil
}

var err error
var parsed *url.URL
switch {
case r.URL.Scheme == "http":
parsed, err = url.Parse(proxy.HttpProxyAddr)
if err != nil {
return err
}
case forceHTTPS || r.URL.Scheme == "https":
parsed, err = url.Parse(proxy.HttpsProxyAddr)
if err != nil {
return err
}
}

if parsed == nil {
return nil
}

if user := parsed.User; user != nil {
r.Header.Set("Proxy-Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(user.String())))
}

return nil
}

func removeProxyHeaders(ctx *ProxyCtx, r *http.Request) {
r.RequestURI = "" // this must be reset when serving a request with the client
ctx.Logf("Sending request %v %v", r.Method, r.URL.String())
Expand Down Expand Up @@ -151,6 +184,9 @@ func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request)

if resp == nil {
removeProxyHeaders(ctx, r)
if err := proxy.addBasicAuth(r, false); err != nil {
ctx.Warnf("Error adding basic auth credential to request %v", err)
}
resp, err = ctx.RoundTrip(r)
if err != nil {
ctx.Error = err
Expand Down Expand Up @@ -213,8 +249,8 @@ func WithHttpsProxyAddr(httpsProxyAddr string) ProxyHttpServerOptions {
// NewProxyHttpServer creates and returns a proxy server, logging to stderr by default
func NewProxyHttpServer(opts ...ProxyHttpServerOptions) *ProxyHttpServer {
appliedOpts := &options{
httpProxyAddr: "",
httpsProxyAddr: "",
httpProxyAddr: "",
httpsProxyAddr: "",
}
for _, opt := range opts {
opt.apply(appliedOpts)
Expand All @@ -237,15 +273,19 @@ func NewProxyHttpServer(opts ...ProxyHttpServerOptions) *ProxyHttpServer {
if appliedOpts.httpProxyAddr != "" {
proxy.HttpProxyAddr = appliedOpts.httpProxyAddr
httpProxyCfg.HTTPProxy = appliedOpts.httpProxyAddr
} else {
proxy.HttpProxyAddr = httpProxyCfg.HTTPProxy
}

if appliedOpts.httpsProxyAddr != "" {
proxy.HttpsProxyAddr = appliedOpts.httpsProxyAddr
httpProxyCfg.HTTPSProxy = appliedOpts.httpsProxyAddr
} else {
proxy.HttpsProxyAddr = httpProxyCfg.HTTPSProxy
}

proxy.ConnectDial = dialerFromProxy(&proxy)

if appliedOpts.httpProxyAddr != "" || appliedOpts.httpsProxyAddr != "" {
proxy.Tr.Proxy = func(req *http.Request) (*url.URL, error) {
return httpProxyCfg.ProxyFunc()(req.URL)
Expand Down
148 changes: 148 additions & 0 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"image"
"io"
"io/ioutil"
Expand All @@ -32,6 +33,49 @@ var https = httptest.NewTLSServer(nil)
var srv = httptest.NewServer(nil)
var fs = httptest.NewServer(http.FileServer(http.Dir(".")))

const (
authUser = "user"
authPass = "pass"
proxyAuthorizationHeader = "Proxy-Authorization"
)

func authed(r *http.Request) bool {
authheader := strings.SplitN(r.Header.Get(proxyAuthorizationHeader), " ", 2)
r.Header.Del(proxyAuthorizationHeader)
if len(authheader) != 2 || authheader[0] != "Basic" {
return false
}
userpassraw, err := base64.StdEncoding.DecodeString(authheader[1])
if err != nil {
return false
}
userpass := strings.SplitN(string(userpassraw), ":", 2)
if len(userpass) != 2 {
return false
}
user, pass := userpass[0], userpass[1]
if user != authUser && pass != authPass {
return false
}
return true
}

var auth = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !authed(r) {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.WriteHeader(http.StatusOK)
}))

var authTLS = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !authed(r) {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.WriteHeader(http.StatusOK)
}))

type QueryHandler struct{}

func (QueryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -73,6 +117,22 @@ func getOrFail(url string, client *http.Client, t *testing.T) []byte {
return txt
}

func authURL() string {
url, err := url.Parse(auth.URL)
if err != nil {
panic(err)
}
return fmt.Sprintf("%s://%s:%s@%s", url.Scheme, authUser, authPass, url.Host)
}

func authTLSURL() string {
url, err := url.Parse(authTLS.URL)
if err != nil {
panic(err)
}
return fmt.Sprintf("%s://%s:%s@%s", url.Scheme, authUser, authPass, url.Host)
}

func localFile(url string) string { return fs.URL + "/" + url }
func localTls(url string) string { return https.URL + url }

Expand Down Expand Up @@ -436,6 +496,94 @@ func TestSimpleMitm(t *testing.T) {
}
}

func TestAuthed_Deny(t *testing.T) {
proxy := goproxy.NewProxyHttpServer(goproxy.WithHttpProxyAddr(auth.URL))

client, l := oneShotProxy(proxy, t)
defer l.Close()

resp, err := client.Head("http://google.com")
panicOnErr(err, "resp to HEAD")
if resp.StatusCode != http.StatusUnauthorized {
t.Error("Status should be a 401")
}
}

func TestAuthed_Pass(t *testing.T) {
proxy := goproxy.NewProxyHttpServer(goproxy.WithHttpProxyAddr(authURL()))

client, l := oneShotProxy(proxy, t)
defer l.Close()

resp, err := client.Head("http://google.com")
panicOnErr(err, "resp to HEAD")
if resp.StatusCode != http.StatusOK {
t.Error("Status should be a 200")
}
}

func TestAuthed_HTTPS(t *testing.T) {
proxy := goproxy.NewProxyHttpServer(goproxy.WithHttpsProxyAddr(authTLSURL()))
proxy.OnRequest(goproxy.ReqHostIs("https://foo")).HandleConnect(goproxy.AlwaysMitm)

_, l := oneShotProxy(proxy, t)
defer l.Close()

c, err := tls.Dial("tcp", https.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("cannot dial to tcp server", err)
}
c.Close()

c2, err := net.Dial("tcp", l.Listener.Addr().String())
if err != nil {
t.Fatal("dialing to proxy", err)
}
creq, err := http.NewRequest("CONNECT", https.URL, nil)
if err != nil {
t.Fatal("create new request", creq)
}
creq.Write(c2)
c2buf := bufio.NewReader(c2)
resp, err := http.ReadResponse(c2buf, creq)
if err != nil || resp.StatusCode != 200 {
t.Fatalf("Cannot CONNECT through proxy %v %d", err, resp.StatusCode)
}
}

func TestAuthed_HTTPSDeny(t *testing.T) {
proxy := goproxy.NewProxyHttpServer(goproxy.WithHttpsProxyAddr(authTLS.URL))
proxy.OnRequest(goproxy.ReqHostIs("https://foo")).HandleConnect(goproxy.AlwaysMitm)

_, l := oneShotProxy(proxy, t)
defer l.Close()

c, err := tls.Dial("tcp", https.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("cannot dial to tcp server", err)
}
c.Close()

c2, err := net.Dial("tcp", l.Listener.Addr().String())
if err != nil {
t.Fatal("dialing to proxy", err)
}
creq, err := http.NewRequest("CONNECT", https.URL, nil)
if err != nil {
t.Fatal("create new request", creq)
}
creq.Write(c2)
c2buf := bufio.NewReader(c2)
resp, err := http.ReadResponse(c2buf, creq)
if err != nil {
t.Fatal("Cannot CONNECT through proxy", err)
}
// if a CONNECT request is denited, goproxy returns a 502
if resp.StatusCode != 502 {
t.Fatal("response should have been denied", resp.StatusCode)
}
}

func TestConnectHandler(t *testing.T) {
proxy := goproxy.NewProxyHttpServer()
althttps := httptest.NewTLSServer(ConstantHanlder("althttps"))
Expand Down