diff --git a/ctx.go b/ctx.go index 76a4c6dc..55f9ca3f 100644 --- a/ctx.go +++ b/ctx.go @@ -25,8 +25,9 @@ type ProxyCtx struct { // call of RespHandler UserData interface{} // Will connect a request to a response - Session int64 - proxy *ProxyHttpServer + Session int64 + proxy *ProxyHttpServer + ConnectAction ConnectActionLiteral } type RoundTripper interface { diff --git a/https.go b/https.go index a60a1953..114718fc 100644 --- a/https.go +++ b/https.go @@ -43,9 +43,10 @@ var ( ) type ConnectAction struct { - Action ConnectActionLiteral - Hijack func(req *http.Request, client net.Conn, ctx *ProxyCtx) - TLSConfig func(host string, ctx *ProxyCtx) (*tls.Config, error) + Action ConnectActionLiteral + Hijack func(req *http.Request, client net.Conn, ctx *ProxyCtx) + TLSConfig func(host string, ctx *ProxyCtx) (*tls.Config, error) + MitmMutateRequest func(req *http.Request, ctx *ProxyCtx) } func stripPort(s string) string { @@ -114,6 +115,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request break } } + ctx.ConnectAction = todo.Action switch todo.Action { case ConnectAccept: if !hasPort.MatchString(host) { @@ -264,7 +266,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request req, err := http.ReadRequest(clientTlsReader) // Set the RoundTripper on the ProxyCtx within the `HandleConnect` action of goproxy, then // inject the roundtripper here in order to use a custom round tripper while mitm. - var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, UserData: ctx.UserData, RoundTripper: ctx.RoundTripper} + var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, UserData: ctx.UserData, RoundTripper: ctx.RoundTripper, ConnectAction: ctx.ConnectAction} if err != nil && err != io.EOF { return } @@ -273,6 +275,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request return } req.RemoteAddr = r.RemoteAddr // since we're converting the request, need to carry over the original connecting IP as well + if todo.MitmMutateRequest != nil { + todo.MitmMutateRequest(req, ctx) + } ctx.Logf("req %v", r.Host) if !httpsRegexp.MatchString(req.URL.String()) { diff --git a/proxy_test.go b/proxy_test.go index 26ee2d85..c15dea04 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -41,9 +41,26 @@ func (QueryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { io.WriteString(w, req.Form.Get("result")) } +type HeadersHandler struct{} + +// This handlers returns a body with a string containing all the request headers it received. +func (HeadersHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + var sb strings.Builder + for name, values := range req.Header { + for _, value := range values { + sb.WriteString(name) + sb.WriteString(": ") + sb.WriteString(value) + sb.WriteString(";") + } + } + io.WriteString(w, sb.String()) +} + func init() { http.DefaultServeMux.Handle("/bobo", ConstantHanlder("bobo")) http.DefaultServeMux.Handle("/query", QueryHandler{}) + http.DefaultServeMux.Handle("/headers", HeadersHandler{}) } type ConstantHanlder string @@ -436,6 +453,33 @@ func TestSimpleMitm(t *testing.T) { } } +func TestMitmMutateRequest(t *testing.T) { + mitmMutateRequest := func(req *http.Request, ctx *goproxy.ProxyCtx) { + // We inject a header in the request + req.Header.Set("Mitm-Header-Inject", "true") + } + mitmConnect := &goproxy.ConnectAction{ + Action: goproxy.ConnectMitm, + TLSConfig: goproxy.TLSConfigFromCA(&goproxy.GoproxyCa), + MitmMutateRequest: mitmMutateRequest, + } + var mitm goproxy.FuncHttpsHandler = func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + return mitmConnect, host + } + + proxy := goproxy.NewProxyHttpServer() + proxy.OnRequest().HandleConnect(mitm) + + client, l := oneShotProxy(proxy, t) + defer l.Close() + + r := string(getOrFail(https.URL+"/headers", client, t)) + if !strings.Contains(r, "Mitm-Header-Inject: true") { + t.Error("Expected response body to contain the MITM injected header. Got instead: ", r) + } + +} + func TestConnectHandler(t *testing.T) { proxy := goproxy.NewProxyHttpServer() althttps := httptest.NewTLSServer(ConstantHanlder("althttps"))