Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make IP Prefix configurable and available ip deterministic #72

Merged
merged 9 commits into from
Aug 3, 2021
Merged
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal
```
"server_url": "http://192.168.1.12:8080",
"listen_addr": "0.0.0.0:8080",
"ip_prefix": "100.64.0.0/10"
```

`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on.
`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `ip_prefix` is the IP prefix (range) in which IP addresses for nodes will be allocated.

```
"private_key_path": "private.key",
Expand Down
1 change: 1 addition & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
log.Println(err)
return
}
log.Printf("Assigning %s to %s", ip, m.Name)

m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()
Expand Down
2 changes: 2 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"golang.org/x/crypto/acme/autocert"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
Expand All @@ -24,6 +25,7 @@ type Config struct {
PrivateKeyPath string
DerpMap *tailcfg.DERPMap
EphemeralNodeInactivityTimeout time.Duration
IPPrefix netaddr.IPPrefix

DBtype string
DBpath string
Expand Down
5 changes: 4 additions & 1 deletion app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"gopkg.in/check.v1"
"inet.af/netaddr"
)

func Test(t *testing.T) {
Expand Down Expand Up @@ -36,7 +37,9 @@ func (s *Suite) ResetDB(c *check.C) {
if err != nil {
c.Fatal(err)
}
cfg := Config{}
cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}

h = Headscale{
cfg: cfg,
Expand Down
1 change: 1 addition & 0 deletions cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
}
h.db.Save(&m)

Expand Down
4 changes: 4 additions & 0 deletions cmd/headscale/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale"
"github.com/spf13/viper"
"gopkg.in/yaml.v2"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)

Expand All @@ -36,6 +37,8 @@ func LoadConfig(path string) error {
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")

viper.SetDefault("ip_prefix", "100.64.0.0/10")

err := viper.ReadInConfig()
if err != nil {
return fmt.Errorf("Fatal error reading config file: %s \n", err)
Expand Down Expand Up @@ -97,6 +100,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
Addr: viper.GetString("listen_addr"),
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
DerpMap: derpMap,
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),

EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),

Expand Down
98 changes: 58 additions & 40 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@ package headscale

import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"

mathrand "math/rand"

"golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/types/wgkey"
)

Expand Down Expand Up @@ -77,47 +71,71 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil
}

func (h *Headscale) getAvailableIP() (*net.IP, error) {
i := 0
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefix := h.cfg.IPPrefix

usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}

// Get the first IP in our prefix
ip := ipPrefix.IP()

for {
ip, err := getRandomIP()
if err != nil {
return nil, err
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
}

// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
// is typically called network or broadcast. Lets avoid them and continue
// to look when we get one of those traditionally reserved IPs.
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
continue
}

if ip.IsZero() &&
ip.IsLoopback() {

ip = ip.Next()
continue
}
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil

if !containsIPs(usedIps, ip) {
return &ip, nil
}
i++
if i == 100 { // really random number
break

ip = ip.Next()
}
}

func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)

ips := make([]netaddr.IP, len(addresses))
for index, addr := range addresses {
if addr != "" {
ip, err := netaddr.ParseIP(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
}

ips[index] = ip
}
}
return nil, errors.New("Could not find an available IP address in 100.64.0.0/10")

return ips, nil
}

func getRandomIP() (*net.IP, error) {
mathrand.Seed(time.Now().Unix())
ipo, ipnet, err := net.ParseCIDR("100.64.0.0/10")
if err == nil {
ip := ipo.To4()
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet)
// fmt.Println("Final address is ", ip)
// fmt.Println("Broadcast address is ", ipb)
// fmt.Println("Network address is ", ipn)
r := mathrand.Uint32()
ipRaw := make([]byte, 4)
binary.LittleEndian.PutUint32(ipRaw, r)
// ipRaw[3] = 254
// fmt.Println("ipRaw is ", ipRaw)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
// fmt.Println("IP After: ", ip[i])
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
// fmt.Println("FINAL IP: ", ip.String())
return &ip, nil
}

return nil, err
return false
}
155 changes: 155 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package headscale

import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)

func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()

c.Assert(err, check.IsNil)

expected := netaddr.MustParseIP("10.27.0.1")

c.Assert(ip.String(), check.Equals, expected.String())
}

func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)

pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)

ips, err := h.getUsedIPs()

c.Assert(err, check.IsNil)

expected := netaddr.MustParseIP("10.27.0.1")

c.Assert(ips[0], check.Equals, expected)

m1, err := h.GetMachineByID(0)
c.Assert(err, check.IsNil)

c.Assert(m1.IPAddress, check.Equals, expected.String())
}

func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)

for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

m := Machine{
ID: uint64(i),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
}

ips, err := h.getUsedIPs()

c.Assert(err, check.IsNil)

c.Assert(len(ips), check.Equals, 350)

c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))

// Check that we can read back the IPs
m1, err := h.GetMachineByID(1)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())

m50, err := h.GetMachineByID(50)
c.Assert(err, check.IsNil)
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())

expectedNextIP := netaddr.MustParseIP("10.27.1.97")
nextIP, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())

// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}

func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

expected := netaddr.MustParseIP("10.27.0.1")

c.Assert(ip.String(), check.Equals, expected.String())

n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)

pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)

ip2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)

c.Assert(ip2.String(), check.Equals, expected.String())
}