-
Notifications
You must be signed in to change notification settings - Fork 269
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #710 from adam-p/network-id
add networkid package and windows implementation
- Loading branch information
Showing
8 changed files
with
626 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
//go:build !windows | ||
|
||
/* | ||
* Copyright (c) 2024, Psiphon Inc. | ||
* All rights reserved. | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
* | ||
*/ | ||
|
||
package networkid | ||
|
||
import "fmt" | ||
|
||
func Enabled() bool { | ||
return false | ||
} | ||
|
||
func Get() (string, error) { | ||
return "", fmt.Errorf("operation is not enabled") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
/* | ||
* Copyright (c) 2024, Psiphon Inc. | ||
* All rights reserved. | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
* | ||
*/ | ||
|
||
package networkid | ||
|
||
import ( | ||
"net" | ||
"net/netip" | ||
"runtime" | ||
"strings" | ||
"sync" | ||
"syscall" | ||
"time" | ||
"unsafe" | ||
|
||
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" | ||
"github.com/go-ole/go-ole" | ||
"golang.org/x/sys/windows" | ||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" | ||
"tailscale.com/wgengine/winnet" | ||
) | ||
|
||
func Enabled() bool { | ||
return true | ||
} | ||
|
||
// Get address associated with the default interface. | ||
func getDefaultLocalAddr() (net.IP, error) { | ||
// Note that this function has no Windows-specific code and could be used elsewhere. | ||
|
||
// This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces() | ||
// The basic idea is that we initialize a UDP connection and see what local | ||
// address the system decides to use. | ||
// Note that no actual network request is made by these calls. They can be performed | ||
// with no network connectivity at all. | ||
// TODO: Use common test IP addresses in that function and this. | ||
|
||
// We'll prefer IPv4 and check it first (both might be available) | ||
ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478")) | ||
ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr) | ||
if ipv4Err == nil { | ||
ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP | ||
ipv4UDPConn.Close() | ||
return ip, nil | ||
} | ||
|
||
ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478")) | ||
ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr) | ||
if ipv6Err == nil { | ||
ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP | ||
ipv6UDPConn.Close() | ||
return ip, nil | ||
} | ||
|
||
return nil, errors.Trace(ipv4Err) | ||
} | ||
|
||
// Given the IP of a local interface, get that interface info. | ||
func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) { | ||
// Note that this function has no Windows-specific code and could be used elsewhere. | ||
|
||
ifaces, err := net.Interfaces() | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
for _, iface := range ifaces { | ||
addrs, err := iface.Addrs() | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
for _, addr := range addrs { | ||
addrIP, _, err := net.ParseCIDR(addr.String()) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
if addrIP.Equal(ip) { | ||
return &iface, nil | ||
} | ||
} | ||
} | ||
|
||
return nil, errors.TraceNew("not found") | ||
} | ||
|
||
// Given the interface index, get info about the interface and its network. | ||
func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) { | ||
luid, err := winipcfg.LUIDFromIndex(uint32(index)) | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
|
||
ifrow, err := luid.Interface() | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
|
||
description = ifrow.Description() + " " + ifrow.Alias() | ||
|
||
ifType = ifrow.Type | ||
|
||
var c ole.Connection | ||
nlm, err := winnet.NewNetworkListManager(&c) | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
defer nlm.Release() | ||
|
||
netConns, err := nlm.GetNetworkConnections() | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
defer netConns.Release() | ||
|
||
for _, nc := range netConns { | ||
ncAdapterID, err := nc.GetAdapterId() | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
if ncAdapterID != ifrow.InterfaceGUID.String() { | ||
continue | ||
} | ||
|
||
// Found the INetworkConnection for the target adapter. | ||
// Get its network and network ID. | ||
|
||
n, err := nc.GetNetwork() | ||
if err != nil { | ||
return "", "", 0, errors.Trace(err) | ||
} | ||
defer n.Release() | ||
|
||
guid := ole.GUID{} | ||
hr, _, _ := syscall.SyscallN( | ||
n.VTable().GetNetworkId, | ||
uintptr(unsafe.Pointer(n)), | ||
uintptr(unsafe.Pointer(&guid))) | ||
if hr != 0 { | ||
return "", "", 0, errors.Tracef("GetNetworkId failed: %08x", hr) | ||
} | ||
|
||
networkID = guid.String() | ||
return networkID, description, ifType, nil | ||
} | ||
|
||
return "", "", 0, errors.Tracef("network connection not found for interface %d", index) | ||
} | ||
|
||
// Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given | ||
// interface type and description. | ||
// If the correct connection type can not be determined, "UNKNOWN" will be returned. | ||
func getConnectionType(ifType winipcfg.IfType, description string) string { | ||
var connectionType string | ||
|
||
switch ifType { | ||
case winipcfg.IfTypeEthernetCSMACD: | ||
connectionType = "WIRED" | ||
case winipcfg.IfTypeIEEE80211: | ||
connectionType = "WIFI" | ||
case winipcfg.IfTypeWwanpp, winipcfg.IfTypeWwanpp2: | ||
connectionType = "MOBILE" | ||
case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel: | ||
connectionType = "VPN" | ||
default: | ||
connectionType = "UNKNOWN" | ||
} | ||
|
||
if connectionType != "VPN" { | ||
// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall | ||
// back to checking for certain words in the description. This feels like a hack, | ||
// but research suggests that it's the best we can do. | ||
|
||
description = strings.ToLower(description) | ||
if strings.Contains(description, "vpn") || | ||
strings.Contains(description, "tunnel") || | ||
strings.Contains(description, "virtual") || | ||
strings.Contains(description, "tap") || | ||
strings.Contains(description, "l2tp") || | ||
strings.Contains(description, "sstp") || | ||
strings.Contains(description, "pptp") || | ||
strings.Contains(description, "openvpn") { | ||
connectionType = "VPN" | ||
} | ||
} | ||
|
||
return connectionType | ||
} | ||
|
||
func getNetworkID() (string, error) { | ||
localAddr, err := getDefaultLocalAddr() | ||
if err != nil { | ||
return "", errors.Trace(err) | ||
} | ||
|
||
iface, err := getInterfaceForLocalIP(localAddr) | ||
if err != nil { | ||
return "", errors.Trace(err) | ||
} | ||
|
||
networkID, description, ifType, err := getInterfaceInfo(iface.Index) | ||
if err != nil { | ||
return "", errors.Trace(err) | ||
} | ||
|
||
connectionType := getConnectionType(ifType, description) | ||
|
||
compoundID := connectionType + "-" + strings.Trim(networkID, "{}") | ||
|
||
return compoundID, nil | ||
} | ||
|
||
type result struct { | ||
networkID string | ||
err error | ||
} | ||
|
||
var workThread struct { | ||
init sync.Once | ||
reqs chan (chan<- result) | ||
err error | ||
|
||
cachedResult string | ||
cacheExpiry time.Time | ||
} | ||
|
||
// Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details. | ||
// This function is safe to call concurrently from multiple goroutines. | ||
// Note that if this function is called immediately after a network change (within ~2000ms) | ||
// a transitory Network ID may be returned that will change on the next call. The caller | ||
// may wish to delay responding to a new Network ID until the value is confirmed. | ||
func Get() (string, error) { | ||
// It is not clear if the COM NetworkListManager calls are threadsafe. We're using them | ||
// read-only and they're probably fine, but we're not sure. Additionally, our networkID | ||
// retrieval code is somewhat slow: 3.5ms. This function gets called by each connection | ||
// attempt (in the horse race, etc.), so this extra time might add ~10% to a such an | ||
// attempt. The value is very unlikely to change in a short amount of time, so it seems | ||
// like a good optimization to cache the result. We'll restrict our work to single | ||
// thread to achieve both goals. | ||
workThread.init.Do(func() { | ||
workThread.reqs = make(chan (chan<- result)) | ||
|
||
go func() { | ||
const resultCacheDuration = 500 * time.Millisecond | ||
|
||
// Go can switch the execution of a goroutine from one OS thread to another | ||
// at (almost) any time. This may or may not be risky to do for our win32 | ||
// (and especially COM) calls, so we're going to explicitly lock this goroutine | ||
// to a single OS thread. This shouldn't have any real impact on performance | ||
// and will help protect against difficult-to-reproduce errors. | ||
runtime.LockOSThread() | ||
defer runtime.UnlockOSThread() | ||
|
||
if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil { | ||
workThread.err = errors.Trace(err) | ||
close(workThread.reqs) | ||
return | ||
} | ||
defer windows.CoUninitialize() | ||
|
||
for resCh := range workThread.reqs { | ||
if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) { | ||
resCh <- result{workThread.cachedResult, nil} | ||
} else { | ||
networkID, err := getNetworkID() | ||
resCh <- result{networkID, err} | ||
workThread.cachedResult = networkID | ||
workThread.cacheExpiry = time.Now().Add(resultCacheDuration) | ||
} | ||
} | ||
}() | ||
}) | ||
|
||
resCh := make(chan result) | ||
workThread.reqs <- resCh | ||
res := <-resCh | ||
|
||
if res.err != nil { | ||
return "", errors.Trace(res.err) | ||
} | ||
|
||
return res.networkID, nil | ||
} |
Oops, something went wrong.