Skip to content

Commit

Permalink
More proxy improvements (#20)
Browse files Browse the repository at this point in the history
* log requests within CONNECT tunnel

* set `Connection: close`

goproxy doesn't hang up the upstream connection when the downstream
connection closes. So, it'll read in the next request and then fail
to send it downstream. In this case it 502's. It would be nicer if
it could close the upstream connection preemtively so the cient
could re-dial. Elixir's HTTPoison tries pooling requests over the
CONNECT tunnel and this results in 502's bubbling up to the calling
app. Setting `Connection: close` prevents this pooling.
  • Loading branch information
btoews authored Oct 4, 2023
1 parent 92c315b commit 55322b4
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 87 deletions.
48 changes: 0 additions & 48 deletions cmd/tokenizer/log.go

This file was deleted.

7 changes: 6 additions & 1 deletion cmd/tokenizer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ func runServe() {

tkz := tokenizer.NewTokenizer(key)

server := &http.Server{Handler: loggingMiddleware(tkz)}
if len(os.Getenv("DEBUG")) != 0 {
tkz.ProxyHttpServer.Verbose = true
tkz.ProxyHttpServer.Logger = logrus.StandardLogger()
}

server := &http.Server{Handler: tkz}

go func() {
if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
Expand Down
25 changes: 8 additions & 17 deletions request_validator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tokenizer

import (
"errors"
"fmt"
"net/http"
"regexp"
Expand All @@ -26,15 +25,11 @@ func AllowHosts(hosts ...string) RequestValidator {
}

func (v allowedHosts) Validate(r *http.Request) error {
host := r.URL.Host
if host == "" {
host = r.Host
if r.Host == "" {
return fmt.Errorf("%w: no host in request", ErrBadRequest)
}
if host == "" {
return errors.New("coun't find host in request")
}
if _, allowed := v[host]; !allowed {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host)
if _, allowed := v[r.Host]; !allowed {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host)
}
return nil
}
Expand All @@ -52,15 +47,11 @@ func AllowHostPattern(pattern *regexp.Regexp) RequestValidator {
}

func (v *allowedHostPattern) Validate(r *http.Request) error {
host := r.URL.Host
if host == "" {
host = r.Host
}
if host == "" {
return errors.New("coun't find host in request")
if r.Host == "" {
return fmt.Errorf("%w: no host in request", ErrBadRequest)
}
if match := (*regexp.Regexp)(v).MatchString(host); !match {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host)
if match := (*regexp.Regexp)(v).MatchString(r.Host); !match {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host)
}
return nil
}
130 changes: 110 additions & 20 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"net"
"net/http"
"strings"
"time"
"unicode"

"github.com/elazarl/goproxy"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/nacl/box"
"golang.org/x/exp/maps"
)

var FilteredHeaders = []string{headerProxyAuthorization, headerProxyTokenizer}
Expand All @@ -41,10 +43,6 @@ type tokenizer struct {
pub *[32]byte
}

var _ goproxy.HttpsHandler = (*tokenizer)(nil)
var _ goproxy.ReqHandler = (*tokenizer)(nil)
var _ http.Handler = new(tokenizer)

func NewTokenizer(openKey string) *tokenizer {
privBytes, err := hex.DecodeString(openKey)
if err != nil {
Expand All @@ -71,8 +69,9 @@ func NewTokenizer(openKey string) *tokenizer {
}
proxy.ConnectDial = nil
proxy.ConnectDialWithReq = nil
proxy.OnRequest().HandleConnect(tkz)
proxy.OnRequest().Do(tkz)
proxy.OnRequest().HandleConnectFunc(tkz.HandleConnect)
proxy.OnRequest().DoFunc(tkz.HandleRequest)
proxy.OnResponse().DoFunc(tkz.HandleResponse)

return tkz
}
Expand All @@ -81,42 +80,84 @@ func (t *tokenizer) SealKey() string {
return hex.EncodeToString(t.pub[:])
}

// HandleConnect implements goproxy.HttpsHandler
// data that we can pass around between callbacks
type proxyUserData struct {
// processors from our handling of the initial CONNECT request if this is a
// tunneled connection.
connectProcessors []RequestProcessor

// start time of the CONNECT request if this is a tunneled connection.
connectStart time.Time
connLog logrus.FieldLogger

// start time of the current request. gets reset between requests within a
// tunneled connection.
requestStart time.Time
reqLog logrus.FieldLogger
}

// HandleConnect implements goproxy.FuncHttpsHandler
func (t *tokenizer) HandleConnect(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
pud := &proxyUserData{
connLog: logrus.WithField("connect_host", host),
connectStart: time.Now(),
}

_, port, _ := strings.Cut(host, ":")
if port == "443" {
logrus.WithField("host", host).Warn("attempt to proxy to https downstream")
pud.connLog.Warn("attempt to proxy to https downstream")
ctx.Resp = errorResponse(ErrBadRequest)
return goproxy.RejectConnect, ""
}

processors, err := t.processorsFromRequest(ctx.Req)
if err != nil {
var err error
if pud.connectProcessors, err = t.processorsFromRequest(ctx.Req); err != nil {
pud.connLog.WithError(err).Warn("find processor (CONNECT)")
ctx.Resp = errorResponse(err)
return goproxy.RejectConnect, ""
}

ctx.UserData = processors
ctx.UserData = pud

return goproxy.HTTPMitmConnect, host
}

// Handle implements goproxy.FuncReqHandler
func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
var processors []RequestProcessor
if ctx.UserData != nil {
processors = ctx.UserData.([]RequestProcessor)
// HandleRequest implements goproxy.FuncReqHandler
func (t *tokenizer) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
if ctx.UserData == nil {
ctx.UserData = &proxyUserData{}
}
pud, ok := ctx.UserData.(*proxyUserData)

if !ok || !pud.requestStart.IsZero() || pud.reqLog != nil {
logrus.Warn("bad proxyUserData")
return nil, errorResponse(ErrInternal)
}

pud.requestStart = time.Now()
if pud.connLog != nil {
pud.reqLog = pud.connLog
} else {
pud.reqLog = logrus.StandardLogger()
}
pud.reqLog = pud.reqLog.WithFields(logrus.Fields{
"method": req.Method,
"host": req.Host,
"path": req.URL.Path,
"queryKeys": strings.Join(maps.Keys(req.URL.Query()), ", "),
})

processors := append([]RequestProcessor(nil), pud.connectProcessors...)
if reqProcessors, err := t.processorsFromRequest(req); err != nil {
logrus.WithError(err).Warn("find processor")
pud.reqLog.WithError(err).Warn("find processor")
return req, errorResponse(err)
} else {
processors = append(processors, reqProcessors...)
}

for _, processor := range processors {
if err := processor(req); err != nil {
logrus.WithError(err).Warn("run processor")
pud.reqLog.WithError(err).Warn("run processor")
return nil, errorResponse(ErrBadRequest)
}
}
Expand All @@ -128,6 +169,47 @@ func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Requ
return req, nil
}

// HandleResponse implements goproxy.FuncRespHandler
func (t *tokenizer) HandleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
// This callback is hit twice if there was an error in the downstream
// request. The first time a nil request is given and the second time we're
// given whatever we returned the first time. Skip logging on the second
// call. This should continue to work okay if
// https://github.com/elazarl/goproxy/pull/512 is ever merged.
if ctx.Error != nil && resp != nil {
return resp
}

pud, ok := ctx.UserData.(*proxyUserData)
if !ok || pud.requestStart.IsZero() || pud.reqLog == nil {
logrus.Warn("missing proxyUserData")
return errorResponse(ErrInternal)
}

log := pud.reqLog.WithField("durMS", int64(time.Since(pud.requestStart)/time.Millisecond))

if !pud.connectStart.IsZero() {
log = log.WithField("connDurMS", int64(time.Since(pud.connectStart)/time.Millisecond))
}
if resp != nil {
log = log.WithField("status", resp.StatusCode)
resp.Header.Set("Connection", "close")
}

// reset pud for next request in tunnel
pud.requestStart = time.Time{}
pud.reqLog = nil

if ctx.Error != nil {
log.WithError(ctx.Error).Warn()
return errorResponse(ctx.Error)
}

log.Info()

return resp
}

func (t *tokenizer) processorsFromRequest(req *http.Request) ([]RequestProcessor, error) {
hdrs := req.Header[headerProxyTokenizer]
processors := make([]RequestProcessor, 0, len(hdrs))
Expand Down Expand Up @@ -215,13 +297,20 @@ func errorResponse(err error) *http.Response {
status = http.StatusBadGateway
}

return &http.Response{StatusCode: status, Body: io.NopCloser(bytes.NewReader([]byte(err.Error())))}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewReader([]byte(err.Error()))),
Header: make(http.Header),
}
}

func forceTLSDialer(network, addr string) (net.Conn, error) {
if network != "tcp" {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("%w: dialing network %s not supported", ErrBadRequest, network)
}

hostname, port, _ := strings.Cut(addr, ":")
if hostname == "" {
return nil, fmt.Errorf("%w: attempt to dial without host: %q", ErrBadRequest, addr)
Expand All @@ -233,5 +322,6 @@ func forceTLSDialer(network, addr string) (net.Conn, error) {
port = "443"
}
addr = fmt.Sprintf("%s:%s", hostname, port)

return tls.Dial("tcp", addr, &tls.Config{RootCAs: upstreamTrust})
}
2 changes: 1 addition & 1 deletion tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestTokenizer(t *testing.T) {
// TLS error (proxy doesn't trust upstream)
resp, err := client.Get(appURL)
assert.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

// make proxy trust upstream
upstreamTrust.AddCert(appServer.Certificate())
Expand Down

0 comments on commit 55322b4

Please sign in to comment.