Skip to content

Commit

Permalink
Merge pull request #710 from adam-p/network-id
Browse files Browse the repository at this point in the history
add networkid package and windows implementation
  • Loading branch information
rod-hynes committed Dec 5, 2024
1 parent 1393a17 commit 3061945
Show file tree
Hide file tree
Showing 8 changed files with 626 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ jobs:
go test -v -timeout 30m -race ./psiphon
go test -v -race ./ClientLibrary/clientlib
go test -v -race ./Server/logging/analysis
go test -v -race ./psiphon/common/networkid
# TODO: fix and re-enable test
# sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./psiphon/common/tun
Expand Down Expand Up @@ -133,6 +134,7 @@ jobs:
go test -v -timeout 30m -covermode=count -coverprofile=psiphon.coverprofile ./psiphon
go test -v -covermode=count -coverprofile=clientlib.coverprofile ./ClientLibrary/clientlib
go test -v -covermode=count -coverprofile=analysis.coverprofile ./Server/logging/analysis
go test -v -covermode=count -coverprofile=networkid.coverprofile ./psiphon/common/networkid
$GOPATH/bin/gover
$GOPATH/bin/goveralls -coverprofile=gover.coverprofile -service=github -repotoken "$COVERALLS_TOKEN"
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ require (
github.com/florianl/go-nfqueue v1.1.1-0.20200829120558-a2f196e98ab0
github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4
github.com/fxamacker/cbor/v2 v2.5.0
github.com/go-ole/go-ole v1.3.0
github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
github.com/google/gopacket v1.1.19
Expand Down Expand Up @@ -85,6 +86,7 @@ require (
golang.org/x/term v0.19.0
golang.org/x/time v0.5.0
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
tailscale.com v1.58.2
)

Expand All @@ -102,7 +104,6 @@ require (
github.com/dchest/siphash v1.2.3 // indirect
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
github.com/gaukas/godicttls v0.0.4 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
Expand Down Expand Up @@ -152,7 +153,6 @@ require (
golang.org/x/mod v0.14.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.15.0 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
32 changes: 32 additions & 0 deletions psiphon/common/networkid/networkid_disabled.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//go:build !windows

/*
* Copyright (c) 2024, Psiphon Inc.
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

package networkid

import "fmt"

func Enabled() bool {
return false
}

func Get() (string, error) {
return "", fmt.Errorf("operation is not enabled")
}
300 changes: 300 additions & 0 deletions psiphon/common/networkid/networkid_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
/*
* Copyright (c) 2024, Psiphon Inc.
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

package networkid

import (
"net"
"net/netip"
"runtime"
"strings"
"sync"
"syscall"
"time"
"unsafe"

"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/go-ole/go-ole"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/wgengine/winnet"
)

func Enabled() bool {
return true
}

// Get address associated with the default interface.
func getDefaultLocalAddr() (net.IP, error) {
// Note that this function has no Windows-specific code and could be used elsewhere.

// This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
// The basic idea is that we initialize a UDP connection and see what local
// address the system decides to use.
// Note that no actual network request is made by these calls. They can be performed
// with no network connectivity at all.
// TODO: Use common test IP addresses in that function and this.

// We'll prefer IPv4 and check it first (both might be available)
ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478"))
ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr)
if ipv4Err == nil {
ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP
ipv4UDPConn.Close()
return ip, nil
}

ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478"))
ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr)
if ipv6Err == nil {
ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP
ipv6UDPConn.Close()
return ip, nil
}

return nil, errors.Trace(ipv4Err)
}

// Given the IP of a local interface, get that interface info.
func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) {
// Note that this function has no Windows-specific code and could be used elsewhere.

ifaces, err := net.Interfaces()
if err != nil {
return nil, errors.Trace(err)
}

for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
return nil, errors.Trace(err)
}

for _, addr := range addrs {
addrIP, _, err := net.ParseCIDR(addr.String())
if err != nil {
return nil, errors.Trace(err)
}

if addrIP.Equal(ip) {
return &iface, nil
}
}
}

return nil, errors.TraceNew("not found")
}

// Given the interface index, get info about the interface and its network.
func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) {
luid, err := winipcfg.LUIDFromIndex(uint32(index))
if err != nil {
return "", "", 0, errors.Trace(err)
}

ifrow, err := luid.Interface()
if err != nil {
return "", "", 0, errors.Trace(err)
}

description = ifrow.Description() + " " + ifrow.Alias()

ifType = ifrow.Type

var c ole.Connection
nlm, err := winnet.NewNetworkListManager(&c)
if err != nil {
return "", "", 0, errors.Trace(err)
}
defer nlm.Release()

netConns, err := nlm.GetNetworkConnections()
if err != nil {
return "", "", 0, errors.Trace(err)
}
defer netConns.Release()

for _, nc := range netConns {
ncAdapterID, err := nc.GetAdapterId()
if err != nil {
return "", "", 0, errors.Trace(err)
}
if ncAdapterID != ifrow.InterfaceGUID.String() {
continue
}

// Found the INetworkConnection for the target adapter.
// Get its network and network ID.

n, err := nc.GetNetwork()
if err != nil {
return "", "", 0, errors.Trace(err)
}
defer n.Release()

guid := ole.GUID{}
hr, _, _ := syscall.SyscallN(
n.VTable().GetNetworkId,
uintptr(unsafe.Pointer(n)),
uintptr(unsafe.Pointer(&guid)))
if hr != 0 {
return "", "", 0, errors.Tracef("GetNetworkId failed: %08x", hr)
}

networkID = guid.String()
return networkID, description, ifType, nil
}

return "", "", 0, errors.Tracef("network connection not found for interface %d", index)
}

// Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
// interface type and description.
// If the correct connection type can not be determined, "UNKNOWN" will be returned.
func getConnectionType(ifType winipcfg.IfType, description string) string {
var connectionType string

switch ifType {
case winipcfg.IfTypeEthernetCSMACD:
connectionType = "WIRED"
case winipcfg.IfTypeIEEE80211:
connectionType = "WIFI"
case winipcfg.IfTypeWwanpp, winipcfg.IfTypeWwanpp2:
connectionType = "MOBILE"
case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
connectionType = "VPN"
default:
connectionType = "UNKNOWN"
}

if connectionType != "VPN" {
// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
// back to checking for certain words in the description. This feels like a hack,
// but research suggests that it's the best we can do.

description = strings.ToLower(description)
if strings.Contains(description, "vpn") ||
strings.Contains(description, "tunnel") ||
strings.Contains(description, "virtual") ||
strings.Contains(description, "tap") ||
strings.Contains(description, "l2tp") ||
strings.Contains(description, "sstp") ||
strings.Contains(description, "pptp") ||
strings.Contains(description, "openvpn") {
connectionType = "VPN"
}
}

return connectionType
}

func getNetworkID() (string, error) {
localAddr, err := getDefaultLocalAddr()
if err != nil {
return "", errors.Trace(err)
}

iface, err := getInterfaceForLocalIP(localAddr)
if err != nil {
return "", errors.Trace(err)
}

networkID, description, ifType, err := getInterfaceInfo(iface.Index)
if err != nil {
return "", errors.Trace(err)
}

connectionType := getConnectionType(ifType, description)

compoundID := connectionType + "-" + strings.Trim(networkID, "{}")

return compoundID, nil
}

type result struct {
networkID string
err error
}

var workThread struct {
init sync.Once
reqs chan (chan<- result)
err error

cachedResult string
cacheExpiry time.Time
}

// Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
// This function is safe to call concurrently from multiple goroutines.
// Note that if this function is called immediately after a network change (within ~2000ms)
// a transitory Network ID may be returned that will change on the next call. The caller
// may wish to delay responding to a new Network ID until the value is confirmed.
func Get() (string, error) {
// It is not clear if the COM NetworkListManager calls are threadsafe. We're using them
// read-only and they're probably fine, but we're not sure. Additionally, our networkID
// retrieval code is somewhat slow: 3.5ms. This function gets called by each connection
// attempt (in the horse race, etc.), so this extra time might add ~10% to a such an
// attempt. The value is very unlikely to change in a short amount of time, so it seems
// like a good optimization to cache the result. We'll restrict our work to single
// thread to achieve both goals.
workThread.init.Do(func() {
workThread.reqs = make(chan (chan<- result))

go func() {
const resultCacheDuration = 500 * time.Millisecond

// Go can switch the execution of a goroutine from one OS thread to another
// at (almost) any time. This may or may not be risky to do for our win32
// (and especially COM) calls, so we're going to explicitly lock this goroutine
// to a single OS thread. This shouldn't have any real impact on performance
// and will help protect against difficult-to-reproduce errors.
runtime.LockOSThread()
defer runtime.UnlockOSThread()

if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil {
workThread.err = errors.Trace(err)
close(workThread.reqs)
return
}
defer windows.CoUninitialize()

for resCh := range workThread.reqs {
if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) {
resCh <- result{workThread.cachedResult, nil}
} else {
networkID, err := getNetworkID()
resCh <- result{networkID, err}
workThread.cachedResult = networkID
workThread.cacheExpiry = time.Now().Add(resultCacheDuration)
}
}
}()
})

resCh := make(chan result)
workThread.reqs <- resCh
res := <-resCh

if res.err != nil {
return "", errors.Trace(res.err)
}

return res.networkID, nil
}
Loading

0 comments on commit 3061945

Please sign in to comment.