Skip to content

Commit

Permalink
Use FrontingSpecs for tunneled downloads
Browse files Browse the repository at this point in the history
  • Loading branch information
mirokuratczyk committed Oct 20, 2023
1 parent 68fbee9 commit e39606e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
1 change: 1 addition & 0 deletions psiphon/feedback.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ func SendFeedback(ctx context.Context, config *Config, diagnostics, uploadPath s
// redefines ResolveIP such that the corresponding fronting
// provider ID is passed into UntunneledResolveIP to enable the use
// of pre-resolved IPs.
// TODO: do not use pre-resolved IPs when tunneled.
IPs, err := UntunneledResolveIP(
ctx, config, resolver, hostname, "")
if err != nil {
Expand Down
101 changes: 72 additions & 29 deletions psiphon/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func UntunneledResolveIP(
return IPs, nil
}

// makeUntunneledFrontedHTTPClient returns a net/http.Client which is
// makeFrontedHTTPClient returns a net/http.Client which is
// configured to use domain fronting and custom dialing features -- including
// BindToDevice, etc. One or more fronting specs must be provided, i.e.
// len(frontingSpecs) must be greater than 0. A function is returned which,
Expand All @@ -400,10 +400,11 @@ func UntunneledResolveIP(
//
// The context is applied to underlying TCP dials. The caller is responsible
// for applying the context to requests made with the returned http.Client.
func makeUntunneledFrontedHTTPClient(
func makeFrontedHTTPClient(
ctx context.Context,
config *Config,
untunneledDialConfig *DialConfig,
tunneled bool,
dialConfig *DialConfig,
frontingSpecs parameters.FrontingSpecs,
selectedFrontingProviderID func(string),
skipVerify bool,
Expand Down Expand Up @@ -499,26 +500,31 @@ func makeUntunneledFrontedHTTPClient(
var resolvedIPAddress atomic.Value
resolvedIPAddress.Store("")

// The default untunneled dial config does not support pre-resolved IPs so
// redefine the dial config to override ResolveIP with an implementation
// that enables their use by passing the fronting provider ID into
// UntunneledResolveIP.
meekDialConfig := &DialConfig{
UpstreamProxyURL: untunneledDialConfig.UpstreamProxyURL,
CustomHeaders: untunneledDialConfig.CustomHeaders,
DeviceBinder: untunneledDialConfig.DeviceBinder,
IPv6Synthesizer: untunneledDialConfig.IPv6Synthesizer,
ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) {
IPs, err := UntunneledResolveIP(
ctx, config, config.GetResolver(), hostname, frontingProviderID)
if err != nil {
return nil, errors.Trace(err)
}
return IPs, nil
},
ResolvedIPCallback: func(IPAddress string) {
resolvedIPAddress.Store(IPAddress)
},
var meekDialConfig *DialConfig
if tunneled {
meekDialConfig = dialConfig
} else {
// The default untunneled dial config does not support pre-resolved IPs so
// redefine the dial config to override ResolveIP with an implementation
// that enables their use by passing the fronting provider ID into
// UntunneledResolveIP.
meekDialConfig = &DialConfig{
UpstreamProxyURL: dialConfig.UpstreamProxyURL,
CustomHeaders: dialConfig.CustomHeaders,
DeviceBinder: dialConfig.DeviceBinder,
IPv6Synthesizer: dialConfig.IPv6Synthesizer,
ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) {
IPs, err := UntunneledResolveIP(
ctx, config, config.GetResolver(), hostname, frontingProviderID)
if err != nil {
return nil, errors.Trace(err)
}
return IPs, nil
},
ResolvedIPCallback: func(IPAddress string) {
resolvedIPAddress.Store(IPAddress)
},
}
}

selectedUserAgent, userAgent := selectUserAgentIfUnset(p, meekDialConfig.CustomHeaders)
Expand Down Expand Up @@ -654,9 +660,10 @@ func MakeUntunneledHTTPClient(

// Ignore skipVerify because it only applies when there are no
// fronting specs.
httpClient, getParams, err := makeUntunneledFrontedHTTPClient(
httpClient, getParams, err := makeFrontedHTTPClient(
ctx,
config,
false,
untunneledDialConfig,
frontingSpecs,
selectedFrontingProviderID,
Expand Down Expand Up @@ -704,9 +711,13 @@ func MakeUntunneledHTTPClient(
// dialing and, optionally, UseTrustedCACertificatesForStockTLS.
// This http.Client uses stock TLS for HTTPS.
func MakeTunneledHTTPClient(
ctx context.Context,
config *Config,
tunnel *Tunnel,
skipVerify bool) (*http.Client, error) {
skipVerify bool,
disableSystemRootCAs bool,
frontingSpecs parameters.FrontingSpecs,
selectedFrontingProviderID func(string)) (*http.Client, func() common.APIParameters, error) {

// Note: there is no dial context since SSH port forward dials cannot
// be interrupted directly. Closing the tunnel will interrupt the dials.
Expand All @@ -718,6 +729,32 @@ func MakeTunneledHTTPClient(
return conn, errors.Trace(err)
}

if len(frontingSpecs) > 0 {

dialConfig := &DialConfig{
TrustedCACertificatesFilename: config.TrustedCACertificatesFilename,
CustomDialer: func(_ context.Context, _, addr string) (net.Conn, error) {
return tunneledDialer("", addr)
},
}

// Ignore skipVerify because it only applies when there are no
// fronting specs.
httpClient, getParams, err := makeFrontedHTTPClient(
ctx,
config,
true,
dialConfig,
frontingSpecs,
selectedFrontingProviderID,
false,
disableSystemRootCAs)
if err != nil {
return nil, nil, errors.Trace(err)
}
return httpClient, getParams, nil
}

transport := &http.Transport{
Dial: tunneledDialer,
}
Expand All @@ -731,15 +768,15 @@ func MakeTunneledHTTPClient(
rootCAs := x509.NewCertPool()
certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
rootCAs.AppendCertsFromPEM(certData)
transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs}
}

return &http.Client{
Transport: transport,
}, nil
}, nil, nil
}

// MakeDownloadHTTPClient is a helper that sets up a http.Client for use either
Expand All @@ -766,8 +803,14 @@ func MakeDownloadHTTPClient(

if tunneled {

httpClient, err = MakeTunneledHTTPClient(
config, tunnel, skipVerify || disableSystemRootCAs)
httpClient, getParams, err = MakeTunneledHTTPClient(
ctx,
config,
tunnel,
skipVerify || disableSystemRootCAs,
disableSystemRootCAs,
frontingSpecs,
selectedFrontingProviderID)
if err != nil {
return nil, false, nil, errors.Trace(err)
}
Expand Down

0 comments on commit e39606e

Please sign in to comment.