Skip to content

Commit

Permalink
Provide proxy listener mode from reversetunnel.Resolver (#16434)
Browse files Browse the repository at this point in the history
By only providing the tunnel address from the `reversetunnel.Resolver`
callers would still need to lookup the proxy listener mode to determine
how to dial the address. This results in sending a request to
`/webapi/find` once by the resolver to get the tunnel address and then
a second request to `/webapi/find` by users of the `Resolver` to determine
the proxy listener mode. Propagating the listener mode along with the
tunnel address by the `Resolver` ensures only one `/webapi/find` call
is needed.

This is especially impactful because the `reversetunnel.TunnelAuthDialer`
which is used by the auth http client would do this everytime the
`http.Client` connection pool was empty. When the `http.Client` needed
to dial the auth server it was incurring the additional roundtrip to the
proxy.
  • Loading branch information
rosstimothy committed Oct 25, 2022
1 parent 1831ee8 commit 672ac37
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 195 deletions.
25 changes: 20 additions & 5 deletions api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration)
func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}

tunnelAddr, err := resp.Proxy.TunnelAddr()
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -103,6 +107,7 @@ func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dura
if err != nil {
return nil, trace.Wrap(err)
}

return conn, nil
})
}
Expand All @@ -128,25 +133,34 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur
// through the SSH reverse tunnel on the proxy.
func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}

if !resp.Proxy.TLSRoutingEnabled {
return nil, trace.NotImplemented("TLS routing is not enabled")
}

tunnelAddr, err := resp.Proxy.TunnelAddr()
if err != nil {
return nil, trace.Wrap(err)
}

dialer := &net.Dialer{
Timeout: dialTimeout,
KeepAlive: keepAlivePeriod,
}
conn, err = dialer.DialContext(ctx, network, tunnelAddr)
if err != nil {
return nil, trace.Wrap(err)

}

host, _, err := webclient.ParseHostPort(tunnelAddr)
if err != nil {
return nil, trace.Wrap(err)
}

tlsConn := tls.Client(conn, &tls.Config{
NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel},
InsecureSkipVerify: insecure,
Expand All @@ -160,6 +174,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou
if err != nil {
return nil, trace.Wrap(err)
}

return sconn, nil
})
}
Expand Down
70 changes: 38 additions & 32 deletions api/client/webclient/webclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ import (
"strings"
"time"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/net/http/httpproxy"

"github.com/gravitational/teleport/api/client/proxy"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/utils"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"golang.org/x/net/http/httpproxy"
)

// Config specifies information when building requests with the
Expand All @@ -65,6 +66,8 @@ type Config struct {
IgnoreHTTPProxy bool
// Timeout is a timeout for requests.
Timeout time.Duration
// TraceProvider is used to retrieve a Tracer for creating spans
TraceProvider oteltrace.TracerProvider
}

// CheckAndSetDefaults checks and sets defaults
Expand All @@ -79,6 +82,9 @@ func (c *Config) CheckAndSetDefaults() error {
if c.Timeout == 0 {
c.Timeout = defaults.DefaultDialTimeout
}
if c.TraceProvider == nil {
c.TraceProvider = tracing.DefaultProvider()
}
return nil
}

Expand Down Expand Up @@ -114,13 +120,16 @@ func newWebClient(cfg *Config) (*http.Client, error) {
// If these conditions are not met, then the plain-HTTP fallback is not allowed,
// and a the HTTPS failure will be considered final.
func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[string]string, req *http.Request) (*http.Response, error) {
span := oteltrace.SpanFromContext(req.Context())

// first try https and see how that goes
req.URL.Scheme = "https"
for k, v := range extraHeaders {
req.Header.Add(k, v)
}

log.Debugf("Attempting %s %s%s", req.Method, req.URL.Host, req.URL.Path)
span.AddEvent("sending https request")
resp, err := clt.Do(req)

// If the HTTPS succeeds, return that.
Expand All @@ -139,6 +148,7 @@ func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[stri
// clear-text HTTP to see if that works.
req.URL.Scheme = "http"
log.Warnf("Request for %s %s%s falling back to PLAIN HTTP", req.Method, req.URL.Host, req.URL.Path)
span.AddEvent("falling back to http request")
resp, err = clt.Do(req)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -156,9 +166,12 @@ func Find(cfg *Config) (*PingResponse, error) {
}
defer clt.CloseIdleConnections()

ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Find")
defer span.End()

endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -189,12 +202,15 @@ func Ping(cfg *Config) (*PingResponse, error) {
}
defer clt.CloseIdleConnections()

ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Ping")
defer span.End()

endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr)
if cfg.ConnectorName != "" {
endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName)
}

req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -219,40 +235,19 @@ func Ping(cfg *Config) (*PingResponse, error) {
return pr, nil
}

// GetTunnelAddr returns the tunnel address either set in an environment variable or retrieved from the web proxy.
func GetTunnelAddr(cfg *Config) (string, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return "", trace.Wrap(err)
}
// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
return parseAndJoinHostPort(tunnelAddr)
}

// Ping web proxy to retrieve tunnel proxy address.
pr, err := Find(cfg)
if err != nil {
return "", trace.Wrap(err)
}
// DELETE IN 11.0.0
// newer proxies should return WebListenAddr so
// we don't need to rely on the dialed proxyAddr
if pr.Proxy.SSH.WebListenAddr == "" {
pr.Proxy.SSH.WebListenAddr = cfg.ProxyAddr
}
return pr.Proxy.tunnelProxyAddr()
}

func GetMOTD(cfg *Config) (*MotD, error) {
clt, err := newWebClient(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
defer clt.CloseIdleConnections()

ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/GetMOTD")
defer span.End()

endpoint := fmt.Sprintf("https://%s/webapi/motd", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -434,6 +429,17 @@ type GithubSettings struct {
Display string `json:"display"`
}

func (ps *ProxySettings) TunnelAddr() (string, error) {
// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
addr, err := parseAndJoinHostPort(tunnelAddr)
return addr, trace.Wrap(err)
}

addr, err := ps.tunnelProxyAddr()
return addr, trace.Wrap(err)
}

// tunnelProxyAddr returns the tunnel proxy address for the proxy settings.
func (ps *ProxySettings) tunnelProxyAddr() (string, error) {
if ps.TLSRoutingEnabled {
Expand Down
Loading

0 comments on commit 672ac37

Please sign in to comment.