Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run proxy finder function and add result to request context #87

Merged
merged 2 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -87,7 +87,7 @@ func main() {

pacWrapper := NewPACWrapper(PACData{Port: *port})
proxyFinder := NewProxyFinder(pacURL, pacWrapper)
proxyHandler := NewProxyHandler(proxyFinder.findProxyForRequest, a, proxyFinder.blockProxy)
proxyHandler := NewProxyHandler(a, getProxyFromContext, proxyFinder.blockProxy)
mux := http.NewServeMux()
pacWrapper.SetupHandlers(mux)

Expand Down
35 changes: 21 additions & 14 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,7 +37,7 @@ type ProxyHandler struct {

type proxyFunc func(*http.Request) (*url.URL, error)

func NewProxyHandler(proxy proxyFunc, auth *authenticator, block func(string)) ProxyHandler {
func NewProxyHandler(auth *authenticator, proxy proxyFunc, block func(string)) ProxyHandler {
return ProxyHandler{&http.Transport{Proxy: proxy}, auth, block}
}

Expand Down Expand Up @@ -69,22 +69,20 @@ func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {

func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
// Establish a connection to the server, or an upstream proxy.
u, err := ph.transport.Proxy(req)
id := req.Context().Value(contextKeyID)
proxy, err := ph.transport.Proxy(req)
if err != nil {
log.Printf("[%d] Error finding proxy for %v: %v", id, req.Host, err)
w.WriteHeader(http.StatusInternalServerError)
return
log.Printf("[%d] Error finding proxy for request: %v", id, err)
}
var server net.Conn
if u == nil {
server, err = net.Dial("tcp", req.Host)
if proxy == nil {
server, err = connectDirect(req)
} else {
server, err = connectViaProxy(req, u.Host, ph.auth)
server, err = connectViaProxy(req, proxy.Host, ph.auth)
var dialErr *dialError
if errors.As(err, &dialErr) {
log.Printf("[%d] Temporarily blocking unreachable proxy: %q", id, u.Host)
ph.block(u.Host)
log.Printf("[%d] Temporarily blocking proxy: %q", id, proxy.Host)
ph.block(proxy.Host)
}
}
if err != nil {
Expand Down Expand Up @@ -130,12 +128,21 @@ func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
go func() { _, _ = io.Copy(client, server); client.Close() }()
}

func connectDirect(req *http.Request) (net.Conn, error) {
server, err := net.Dial("tcp", req.Host)
if err != nil {
id := req.Context().Value(contextKeyID)
log.Printf("[%d] Error dialling host %s: %v", id, req.Host, err)
}
return server, err
}

func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.Conn, error) {
id := req.Context().Value(contextKeyID)
var tr transport
defer tr.Close()
if err := tr.dial("tcp", proxy); err != nil {
log.Printf("[%d] Error dialling %s: %v", id, proxy, err)
log.Printf("[%d] Error dialling proxy %s: %v", id, proxy, err)
return nil, err
}
resp, err := tr.RoundTrip(req)
Expand Down Expand Up @@ -176,7 +183,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au
resp, err := ph.transport.RoundTrip(req)
if err != nil {
log.Printf("[%d] Error forwarding request: %v", id, err)
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusBadGateway)
var dialErr *dialError
if errors.As(err, &dialErr) && dialErr.address != req.Host {
log.Printf("[%d] Temporarily blocking unreachable proxy: %q",
Expand All @@ -196,7 +203,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au
resp, err = auth.do(req, ph.transport)
if err != nil {
log.Printf("[%d] Error forwarding request (with auth): %v", id, err)
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusBadGateway)
return
}
defer resp.Body.Close()
Expand Down
29 changes: 19 additions & 10 deletions proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@ package main

import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -54,17 +55,17 @@ func (tp testProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func newDirectProxy() ProxyHandler {
return NewProxyHandler(
func(r *http.Request) (*url.URL, error) { return nil, nil },
nil,
func(string) {},
)
return NewProxyHandler(nil, http.ProxyURL(nil), func(string) {})
}

func newChildProxy(parent *httptest.Server) ProxyHandler {
return NewProxyHandler(func(r *http.Request) (*url.URL, error) {
return &url.URL{Host: parent.Listener.Addr().String()}, nil
}, nil, func(string) {})
func newChildProxy(parent *httptest.Server) http.Handler {
parentURL := &url.URL{Host: parent.Listener.Addr().String()}
childProxy := NewProxyHandler(nil, getProxyFromContext, func(string) {})
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), contextKeyProxy, parentURL)
reqWithProxy := req.WithContext(ctx)
childProxy.ServeHTTP(w, reqWithProxy)
})
}

func proxyServer(t *testing.T, proxy *httptest.Server) proxyFunc {
Expand Down Expand Up @@ -336,3 +337,11 @@ func TestConnectResponseHasCorrectNewlines(t *testing.T) {
assert.NotContains(t, noCRLFs, "\r", "response contains unmatched CR")
assert.NotContains(t, noCRLFs, "\n", "response contains unmatched LF")
}

func TestConnectToNonExistentHost(t *testing.T) {
proxy := httptest.NewServer(newDirectProxy())
defer proxy.Close()
client := http.Client{Transport: &http.Transport{Proxy: proxyServer(t, proxy)}}
_, err := client.Get("https://nonexistent.test")
require.Error(t, err)
}
22 changes: 20 additions & 2 deletions proxyfinder.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
package main

import (
"context"
"errors"
"log"
"net"
Expand All @@ -24,6 +25,16 @@ import (
"sync"
)

const contextKeyProxy = contextKey("proxy")

func getProxyFromContext(req *http.Request) (*url.URL, error) {
if value := req.Context().Value(contextKeyProxy); value != nil {
proxy := value.(*url.URL)
return proxy, nil
}
return nil, nil
}

type ProxyFinder struct {
runner *PACRunner
fetcher *pacFetcher
Expand Down Expand Up @@ -53,7 +64,14 @@ func (pf *ProxyFinder) WrapHandler(next http.Handler) http.Handler {
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
pf.checkForUpdates()
next.ServeHTTP(w, req)
proxy, err := pf.findProxyForRequest(req)
if err != nil {
log.Printf("[%d] %v", req.Context().Value(contextKeyID), err)
w.WriteHeader(http.StatusInternalServerError)
return
}
ctx := context.WithValue(req.Context(), contextKeyProxy, proxy)
next.ServeHTTP(w, req.WithContext(ctx))
})
}

Expand Down