diff --git a/psiphon/feedback.go b/psiphon/feedback.go index daadf0f64..e6884f549 100644 --- a/psiphon/feedback.go +++ b/psiphon/feedback.go @@ -166,6 +166,7 @@ func SendFeedback(ctx context.Context, config *Config, diagnostics, uploadPath s // redefines ResolveIP such that the corresponding fronting // provider ID is passed into UntunneledResolveIP to enable the use // of pre-resolved IPs. + // TODO: do not use pre-resolved IPs when tunneled. IPs, err := UntunneledResolveIP( ctx, config, resolver, hostname, "") if err != nil { diff --git a/psiphon/net.go b/psiphon/net.go index e708e8dc5..80ddef7cf 100644 --- a/psiphon/net.go +++ b/psiphon/net.go @@ -391,7 +391,7 @@ func UntunneledResolveIP( return IPs, nil } -// makeUntunneledFrontedHTTPClient returns a net/http.Client which is +// makeFrontedHTTPClient returns a net/http.Client which is // configured to use domain fronting and custom dialing features -- including // BindToDevice, etc. One or more fronting specs must be provided, i.e. // len(frontingSpecs) must be greater than 0. A function is returned which, @@ -400,10 +400,11 @@ func UntunneledResolveIP( // // The context is applied to underlying TCP dials. The caller is responsible // for applying the context to requests made with the returned http.Client. -func makeUntunneledFrontedHTTPClient( +func makeFrontedHTTPClient( ctx context.Context, config *Config, - untunneledDialConfig *DialConfig, + tunneled bool, + dialConfig *DialConfig, frontingSpecs parameters.FrontingSpecs, selectedFrontingProviderID func(string), skipVerify bool, @@ -499,26 +500,31 @@ func makeUntunneledFrontedHTTPClient( var resolvedIPAddress atomic.Value resolvedIPAddress.Store("") - // The default untunneled dial config does not support pre-resolved IPs so - // redefine the dial config to override ResolveIP with an implementation - // that enables their use by passing the fronting provider ID into - // UntunneledResolveIP. - meekDialConfig := &DialConfig{ - UpstreamProxyURL: untunneledDialConfig.UpstreamProxyURL, - CustomHeaders: untunneledDialConfig.CustomHeaders, - DeviceBinder: untunneledDialConfig.DeviceBinder, - IPv6Synthesizer: untunneledDialConfig.IPv6Synthesizer, - ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) { - IPs, err := UntunneledResolveIP( - ctx, config, config.GetResolver(), hostname, frontingProviderID) - if err != nil { - return nil, errors.Trace(err) - } - return IPs, nil - }, - ResolvedIPCallback: func(IPAddress string) { - resolvedIPAddress.Store(IPAddress) - }, + var meekDialConfig *DialConfig + if tunneled { + meekDialConfig = dialConfig + } else { + // The default untunneled dial config does not support pre-resolved IPs so + // redefine the dial config to override ResolveIP with an implementation + // that enables their use by passing the fronting provider ID into + // UntunneledResolveIP. + meekDialConfig = &DialConfig{ + UpstreamProxyURL: dialConfig.UpstreamProxyURL, + CustomHeaders: dialConfig.CustomHeaders, + DeviceBinder: dialConfig.DeviceBinder, + IPv6Synthesizer: dialConfig.IPv6Synthesizer, + ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) { + IPs, err := UntunneledResolveIP( + ctx, config, config.GetResolver(), hostname, frontingProviderID) + if err != nil { + return nil, errors.Trace(err) + } + return IPs, nil + }, + ResolvedIPCallback: func(IPAddress string) { + resolvedIPAddress.Store(IPAddress) + }, + } } selectedUserAgent, userAgent := selectUserAgentIfUnset(p, meekDialConfig.CustomHeaders) @@ -654,9 +660,10 @@ func MakeUntunneledHTTPClient( // Ignore skipVerify because it only applies when there are no // fronting specs. - httpClient, getParams, err := makeUntunneledFrontedHTTPClient( + httpClient, getParams, err := makeFrontedHTTPClient( ctx, config, + false, untunneledDialConfig, frontingSpecs, selectedFrontingProviderID, @@ -704,9 +711,13 @@ func MakeUntunneledHTTPClient( // dialing and, optionally, UseTrustedCACertificatesForStockTLS. // This http.Client uses stock TLS for HTTPS. func MakeTunneledHTTPClient( + ctx context.Context, config *Config, tunnel *Tunnel, - skipVerify bool) (*http.Client, error) { + skipVerify bool, + disableSystemRootCAs bool, + frontingSpecs parameters.FrontingSpecs, + selectedFrontingProviderID func(string)) (*http.Client, func() common.APIParameters, error) { // Note: there is no dial context since SSH port forward dials cannot // be interrupted directly. Closing the tunnel will interrupt the dials. @@ -718,6 +729,32 @@ func MakeTunneledHTTPClient( return conn, errors.Trace(err) } + if len(frontingSpecs) > 0 { + + dialConfig := &DialConfig{ + TrustedCACertificatesFilename: config.TrustedCACertificatesFilename, + CustomDialer: func(_ context.Context, _, addr string) (net.Conn, error) { + return tunneledDialer("", addr) + }, + } + + // Ignore skipVerify because it only applies when there are no + // fronting specs. + httpClient, getParams, err := makeFrontedHTTPClient( + ctx, + config, + true, + dialConfig, + frontingSpecs, + selectedFrontingProviderID, + false, + disableSystemRootCAs) + if err != nil { + return nil, nil, errors.Trace(err) + } + return httpClient, getParams, nil + } + transport := &http.Transport{ Dial: tunneledDialer, } @@ -731,7 +768,7 @@ func MakeTunneledHTTPClient( rootCAs := x509.NewCertPool() certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } rootCAs.AppendCertsFromPEM(certData) transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs} @@ -739,7 +776,7 @@ func MakeTunneledHTTPClient( return &http.Client{ Transport: transport, - }, nil + }, nil, nil } // MakeDownloadHTTPClient is a helper that sets up a http.Client for use either @@ -766,8 +803,14 @@ func MakeDownloadHTTPClient( if tunneled { - httpClient, err = MakeTunneledHTTPClient( - config, tunnel, skipVerify || disableSystemRootCAs) + httpClient, getParams, err = MakeTunneledHTTPClient( + ctx, + config, + tunnel, + skipVerify || disableSystemRootCAs, + disableSystemRootCAs, + frontingSpecs, + selectedFrontingProviderID) if err != nil { return nil, false, nil, errors.Trace(err) }