Skip to content

Commit

Permalink
fix(pia): support port forwarding using Wireguard (#2420)
Browse files Browse the repository at this point in the history
- Build API IP address using the first 2 bytes of the gateway IP and adding `128.1` to it
- API IP address is valid for both OpenVPN and Wireguard
- Fix #2320
  • Loading branch information
qdm12 authored Aug 19, 2024
1 parent b3cc278 commit c39edb6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
18 changes: 13 additions & 5 deletions internal/configuration/settings/serverselection.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
}

if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
if vpnServiceProvider == providers.Custom {
switch len(settings.Names) {
case 0:
case 1:
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
}
}
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
if err != nil {
Expand Down
33 changes: 23 additions & 10 deletions internal/provider/privateinternetaccess/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
}

serverName := objects.ServerName

apiIP := buildAPIIPAddress(objects.Gateway)
logger := objects.Logger

if !objects.CanPortForward {
Expand Down Expand Up @@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,

if !dataFound || expired {
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
p.portForwardPath, objects.Username, objects.Password)
if err != nil {
return nil, fmt.Errorf("refreshing port forward data: %w", err)
Expand All @@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))

// First time binding
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
return nil, fmt.Errorf("binding port: %w", err)
}

Expand All @@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
panic("gateway is not set")
}

apiIP := buildAPIIPAddress(objects.Gateway)

privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err)
Expand Down Expand Up @@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
return ctx.Err()
case <-keepAliveTimer.C:
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
err = bindPort(ctx, privateIPClient, apiIP, data)
if err != nil {
return fmt.Errorf("binding port: %w", err)
}
Expand All @@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
}

func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
if gateway.Is6() {
panic("IPv6 gateway not supported")
}

gatewayBytes := gateway.As4()
gatewayBytes[2] = 128
gatewayBytes[3] = 1
return netip.AddrFrom4(gatewayBytes)
}

func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, username, password)
if err != nil {
return data, fmt.Errorf("fetching token: %w", err)
}

data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
if err != nil {
return data, fmt.Errorf("fetching port forwarding data: %w", err)
}
Expand Down Expand Up @@ -286,15 +299,15 @@ func fetchToken(ctx context.Context, client *http.Client,
return result.Token, nil
}

func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}

queryParams := make(url.Values)
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(apiIP.String(), "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
Expand Down Expand Up @@ -340,7 +353,7 @@ var (
ErrBadResponse = errors.New("bad response received")
)

func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return fmt.Errorf("serializing payload: %w", err)
Expand All @@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
queryParams.Add("signature", data.Signature)
bindPortURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}
Expand Down

0 comments on commit c39edb6

Please sign in to comment.