Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
]
},
"git.ignoreLimitWarning": true,
"cmake.sourceDirectory": "/workspaces/kvm-static-ip/internal/native/cgo"
"cmake.sourceDirectory": "/workspaces/kvm-static-ip/internal/native/cgo",
"cmake.ignoreCMakeListsMissing": true
}
6 changes: 6 additions & 0 deletions jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,10 @@ func rpcSetCloudUrl(apiUrl string, appUrl string) error {
disconnectCloud(fmt.Errorf("cloud url changed from %s to %s", currentCloudURL, apiUrl))
}

if publicIPState != nil {
publicIPState.SetCloudflareEndpoint(apiUrl)
}

if err := SaveConfig(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
Expand Down Expand Up @@ -1248,4 +1252,6 @@ var rpcHandlers = map[string]RPCHandler{
"setKeyboardMacros": {Func: setKeyboardMacros, Params: []string{"params"}},
"getLocalLoopbackOnly": {Func: rpcGetLocalLoopbackOnly},
"setLocalLoopbackOnly": {Func: rpcSetLocalLoopbackOnly, Params: []string{"enabled"}},
"getPublicIPAddresses": {Func: rpcGetPublicIPAddresses, Params: []string{"refresh"}},
"checkPublicIPAddresses": {Func: rpcCheckPublicIPAddresses},
}
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func Main() {

// As websocket client already checks if the cloud token is set, we can start it here.
go RunWebsocketClient()
initPublicIPState()

initSerialPort()
sigs := make(chan os.Signal, 1)
Expand Down
71 changes: 71 additions & 0 deletions network.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package kvm
import (
"context"
"fmt"
"net"
"net/http"
"reflect"
"time"

"github.com/jetkvm/kvm/internal/confparser"
"github.com/jetkvm/kvm/internal/mdns"
"github.com/jetkvm/kvm/internal/network/types"
"github.com/jetkvm/kvm/pkg/myip"
"github.com/jetkvm/kvm/pkg/nmlite"
"github.com/jetkvm/kvm/pkg/nmlite/link"
)

const (
Expand All @@ -17,6 +22,7 @@ const (

var (
networkManager *nmlite.NetworkManager
publicIPState *myip.PublicIPState
)

type RpcNetworkSettings struct {
Expand Down Expand Up @@ -104,6 +110,13 @@ func triggerTimeSyncOnNetworkStateChange() {
}()
}

func setPublicIPReadyState(ipv4Ready, ipv6Ready bool) {
if publicIPState == nil {
return
}
publicIPState.SetIPv4AndIPv6(ipv4Ready, ipv6Ready)
}

func networkStateChanged(_ string, state types.InterfaceState) {
// do not block the main thread
go waitCtrlAndRequestDisplayUpdate(true, "network_state_changed")
Expand All @@ -117,6 +130,8 @@ func networkStateChanged(_ string, state types.InterfaceState) {
triggerTimeSyncOnNetworkStateChange()
}

setPublicIPReadyState(state.IPv4Ready, state.IPv6Ready)

// always restart mDNS when the network state changes
if mDNS != nil {
restartMdns()
Expand Down Expand Up @@ -164,6 +179,40 @@ func initNetwork() error {
return nil
}

func initPublicIPState() {
// the feature will be only enabled if the cloud has been adopted
// due to privacy reasons

// but it will be initialized anyway to avoid nil pointer dereferences
ps := myip.NewPublicIPState(&myip.PublicIPStateConfig{
Logger: networkLogger,
CloudflareEndpoint: config.CloudURL,
APIEndpoint: "",
IPv4: false,
IPv6: false,
HttpClientGetter: func(family int) *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Proxy = config.NetworkConfig.GetTransportProxyFunc()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
netType := network
switch family {
case link.AfInet:
netType = "tcp4"
case link.AfInet6:
netType = "tcp6"
}
return (&net.Dialer{}).DialContext(ctx, netType, addr)
}

return &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
},
})
publicIPState = ps
}

func setHostname(nm *nmlite.NetworkManager, hostname, domain string) error {
if nm == nil {
return nil
Expand Down Expand Up @@ -312,3 +361,25 @@ func rpcToggleDHCPClient() error {

return rpcReboot(true)
}

func rpcGetPublicIPAddresses(refresh bool) ([]myip.PublicIP, error) {
if publicIPState == nil {
return nil, fmt.Errorf("public IP state not initialized")
}

if refresh {
if err := publicIPState.ForceUpdate(); err != nil {
return nil, err
}
}

return publicIPState.GetAddresses(), nil
}

func rpcCheckPublicIPAddresses() error {
if publicIPState == nil {
return fmt.Errorf("public IP state not initialized")
}

return publicIPState.ForceUpdate()
}
160 changes: 160 additions & 0 deletions pkg/myip/check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package myip

import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/jetkvm/kvm/pkg/nmlite/link"
)

func (ps *PublicIPState) request(ctx context.Context, url string, family int) ([]byte, error) {
client := ps.httpClient(family)

req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error sending request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}

return body, err
}

// checkCloudflare uses cdn-cgi/trace to get the public IP address
func (ps *PublicIPState) checkCloudflare(ctx context.Context, family int) (*PublicIP, error) {
u, err := url.JoinPath(ps.cloudflareEndpoint, "/cdn-cgi/trace")
if err != nil {
return nil, fmt.Errorf("error joining path: %w", err)
}

body, err := ps.request(ctx, u, family)
if err != nil {
return nil, err
}

values := make(map[string]string)
for line := range strings.SplitSeq(string(body), "\n") {
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
values[key] = value
}

ps.lastUpdated = time.Now()
if ts, ok := values["ts"]; ok {
if ts, err := strconv.ParseFloat(ts, 64); err == nil {
ps.lastUpdated = time.Unix(int64(ts), 0)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We silently discard any failure to parse ts but that's probably okay...

}

ipStr, ok := values["ip"]
if !ok {
return nil, fmt.Errorf("no IP address found")
}

ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", ipStr)
}

return &PublicIP{
IPAddress: ip,
LastUpdated: ps.lastUpdated,
}, nil
}

// checkAPI uses the API endpoint to get the public IP address
func (ps *PublicIPState) checkAPI(_ context.Context, _ int) (*PublicIP, error) {
return nil, fmt.Errorf("not implemented")
}

// checkIPs checks both IPv4 and IPv6 public IP addresses in parallel
// and updates the IPAddresses slice with the results
func (ps *PublicIPState) checkIPs(ctx context.Context, checkIPv4, checkIPv6 bool) error {
var wg sync.WaitGroup
var mu sync.Mutex
var ips []PublicIP
var errors []error

checkFamily := func(family int, familyName string) {
wg.Add(1)
go func(f int, name string) {
defer wg.Done()

ip, err := ps.checkIPForFamily(ctx, f)
mu.Lock()
defer mu.Unlock()
if err != nil {
errors = append(errors, fmt.Errorf("%s check failed: %w", name, err))
return
}
if ip != nil {
ips = append(ips, *ip)
}
}(family, familyName)
}

if checkIPv4 {
checkFamily(link.AfInet, "IPv4")
}

if checkIPv6 {
checkFamily(link.AfInet6, "IPv6")
}

wg.Wait()

if len(ips) > 0 {
ps.mu.Lock()
defer ps.mu.Unlock()

ps.addresses = ips
ps.lastUpdated = time.Now()
}

if len(errors) > 0 && len(ips) == 0 {
return errors[0]
}

return nil
}

func (ps *PublicIPState) checkIPForFamily(ctx context.Context, family int) (*PublicIP, error) {
if ps.apiEndpoint != "" {
ip, err := ps.checkAPI(ctx, family)
if err == nil && ip != nil {
return ip, nil
}
}

if ps.cloudflareEndpoint != "" {
ip, err := ps.checkCloudflare(ctx, family)
if err == nil && ip != nil {
return ip, nil
}
}

return nil, fmt.Errorf("all IP check methods failed for family %d", family)
}
Loading