Skip to content

Commit

Permalink
Reconcile upstream dial addresses and request host/URL information
Browse files Browse the repository at this point in the history
My goodness that was complicated

Blessed be request.Context

Sort of
  • Loading branch information
mholt committed Sep 5, 2019
1 parent a60d54d commit 0830fba
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 183 deletions.
18 changes: 9 additions & 9 deletions listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,19 @@ var (
listenersMu sync.Mutex
)

// ParseListenAddr parses addr, a string of the form "network/host:port"
// ParseNetworkAddress parses addr, a string of the form "network/host:port"
// (with any part optional) into its component parts. Because a port can
// also be a port range, there may be multiple addresses returned.
func ParseListenAddr(addr string) (network string, addrs []string, err error) {
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
var host, port string
network, host, port, err = SplitListenAddr(addr)
network, host, port, err = SplitNetworkAddress(addr)
if network == "" {
network = "tcp"
}
if err != nil {
return
}
if network == "unix" {
if network == "unix" || network == "unixgram" || network == "unixpacket" {
addrs = []string{host}
return
}
Expand All @@ -204,26 +204,26 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) {
return
}

// SplitListenAddr splits a into its network, host, and port components.
// SplitNetworkAddress splits a into its network, host, and port components.
// Note that port may be a port range, or omitted for unix sockets.
func SplitListenAddr(a string) (network, host, port string, err error) {
func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
a = a[idx+1:]
}
if network == "unix" {
if network == "unix" || network == "unixgram" || network == "unixpacket" {
host = a
return
}
host, port, err = net.SplitHostPort(a)
return
}

// JoinListenAddr combines network, host, and port into a single
// JoinNetworkAddress combines network, host, and port into a single
// address string of the form "network/host:port". Port may be a
// port range. For unix sockets, the network should be "unix" and
// the path to the socket should be given in the host argument.
func JoinListenAddr(network, host, port string) string {
func JoinNetworkAddress(network, host, port string) string {
var a string
if network != "" {
a = network + "/"
Expand Down
22 changes: 16 additions & 6 deletions listeners_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"testing"
)

func TestSplitListenerAddr(t *testing.T) {
func TestSplitNetworkAddress(t *testing.T) {
for i, tc := range []struct {
input string
expectNetwork string
Expand Down Expand Up @@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) {
expectNetwork: "unix",
expectHost: "/foo/bar",
},
{
input: "unixgram//foo/bar",
expectNetwork: "unixgram",
expectHost: "/foo/bar",
},
{
input: "unixpacket//foo/bar",
expectNetwork: "unixpacket",
expectHost: "/foo/bar",
},
} {
actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input)
actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}
Expand All @@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) {
}
}

func TestJoinListenerAddr(t *testing.T) {
func TestJoinNetworkAddress(t *testing.T) {
for i, tc := range []struct {
network, host, port string
expect string
Expand Down Expand Up @@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) {
expect: "unix//foo/bar",
},
} {
actual := JoinListenAddr(tc.network, tc.host, tc.port)
actual := JoinNetworkAddress(tc.network, tc.host, tc.port)
if actual != tc.expect {
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
}
}
}

func TestParseListenerAddr(t *testing.T) {
func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct {
input string
expectNetwork string
Expand Down Expand Up @@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) {
expectAddrs: []string{"localhost:0"},
},
} {
actualNetwork, actualAddrs, err := ParseListenAddr(tc.input)
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}
Expand Down
12 changes: 6 additions & 6 deletions modules/caddyhttp/caddyhttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers {
for _, addr := range srv.Listen {
netw, expanded, err := caddy.ParseListenAddr(addr)
netw, expanded, err := caddy.ParseNetworkAddress(addr)
if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (app *App) Start() error {
}

for _, lnAddr := range srv.Listen {
network, addrs, err := caddy.ParseListenAddr(lnAddr)
network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
}
Expand Down Expand Up @@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error {

// create HTTP->HTTPS redirects
for _, addr := range srv.Listen {
netw, host, port, err := caddy.SplitListenAddr(addr)
netw, host, port, err := caddy.SplitNetworkAddress(addr)
if err != nil {
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
}
Expand All @@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error {
if httpPort == 0 {
httpPort = DefaultHTTPPort
}
httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort))
httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort))
lnAddrMap[httpRedirLnAddr] = struct{}{}

if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
Expand Down Expand Up @@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error {
var lnAddrs []string
mapLoop:
for addr := range lnAddrMap {
netw, addrs, err := caddy.ParseListenAddr(addr)
netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil {
continue
}
Expand All @@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error {
func (app *App) listenerTaken(network, address string) bool {
for _, srv := range app.Servers {
for _, addr := range srv.Listen {
netw, addrs, err := caddy.ParseListenAddr(addr)
netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil || netw != network {
continue
}
Expand Down
102 changes: 32 additions & 70 deletions modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"time"

"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
"github.com/caddyserver/caddy/v2/modules/caddytls"

"github.com/caddyserver/caddy/v2"
Expand All @@ -34,6 +35,7 @@ func init() {
caddy.RegisterModule(Transport{})
}

// Transport facilitates FastCGI communication.
type Transport struct {
//////////////////////////////
// TODO: taken from v1 Handler type
Expand All @@ -57,32 +59,32 @@ type Transport struct {

// Use this directory as the fastcgi root directory. Defaults to the root
// directory of the parent virtual host.
Root string
Root string `json:"root,omitempty"`

// The path in the URL will be split into two, with the first piece ending
// with the value of SplitPath. The first piece will be assumed as the
// actual resource (CGI script) name, and the second piece will be set to
// PATH_INFO for the CGI script to use.
SplitPath string
SplitPath string `json:"split_path,omitempty"`

// If the URL ends with '/' (which indicates a directory), these index
// files will be tried instead.
IndexFiles []string
// IndexFiles []string

// Environment Variables
EnvVars [][2]string
EnvVars [][2]string `json:"env,omitempty"`

// Ignored paths
IgnoredSubPaths []string
// IgnoredSubPaths []string

// The duration used to set a deadline when connecting to an upstream.
DialTimeout time.Duration
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`

// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`

// The duration used to set a deadline when sending to the FastCGI server.
WriteTimeout time.Duration
WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
}

// CaddyModule returns the Caddy module information.
Expand All @@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
}
}

// RoundTrip implements http.RoundTripper.
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// Create environment for CGI script
env, err := t.buildEnv(r)
if err != nil {
return nil, fmt.Errorf("building environment: %v", err)
}

// TODO:
// Connect to FastCGI gateway
// address, err := f.Address()
// if err != nil {
// return http.StatusBadGateway, err
// }
// network, address := parseAddress(address)
network, address := "tcp", r.URL.Host // TODO:

// TODO: doesn't dialer have a Timeout field?
ctx := context.Background()
if t.DialTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
defer cancel()
}

// extract dial information from request (this
// should embedded by the reverse proxy)
network, address := "tcp", r.URL.Host
if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
dialInfo := dialInfoVal.(reverseproxy.DialInfo)
network = dialInfo.Network
address = dialInfo.Address
}

fcgiBackend, err := DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dialing backend: %v", err)
}
// fcgiBackend is closed when response body is closed (see clientCloser)
// fcgiBackend gets closed when response body is closed (see clientCloser)

// read/write timeouts
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
return nil, fmt.Errorf("setting read timeout: %v", err)
}
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
return nil, fmt.Errorf("setting write timeout: %v", err)
}

var resp *http.Response

var contentLength int64
// if ContentLength is already set
if r.ContentLength > 0 {
contentLength = r.ContentLength
} else {
contentLength := r.ContentLength
if contentLength == 0 {
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
}

var resp *http.Response
switch r.Method {
case "HEAD":
case http.MethodHead:
resp, err = fcgiBackend.Head(env)
case "GET":
case http.MethodGet:
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS":
case http.MethodOptions:
resp, err = fcgiBackend.Options(env)
default:
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
}

// TODO:
return resp, err

// Stuff brought over from v1 that might not be necessary here:

// if resp != nil && resp.Body != nil {
// defer resp.Body.Close()
// }

// if err != nil {
// if err, ok := err.(net.Error); ok && err.Timeout() {
// return http.StatusGatewayTimeout, err
// } else if err != io.EOF {
// return http.StatusBadGateway, err
// }
// }

// // Write response header
// writeHeader(w, resp)

// // Write the response body
// _, err = io.Copy(w, resp.Body)
// if err != nil {
// return http.StatusBadGateway, err
// }

// // Log any stderr output from upstream
// if fcgiBackend.stderr.Len() != 0 {
// // Remove trailing newline, error logger already does this.
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
// }

// // Normally we would return the status code if it is an error status (>= 400),
// // however, upstream FastCGI apps don't know about our contract and have
// // probably already written an error page. So we just return 0, indicating
// // that the response body is already written. However, we do return any
// // error value so it can be logged.
// // Note that the proxy middleware works the same way, returning status=0.
// return 0, err
}

// buildEnv returns a set of CGI environment variables for the request.
Expand Down
Loading

0 comments on commit 0830fba

Please sign in to comment.