Skip to content

Commit

Permalink
Support automatic discovery of IPs from AWS metadata (#3)
Browse files Browse the repository at this point in the history
* refactor server creation

* feat: dynamic start/stop from aws

* only log errors starting server

* fix: various improvements
  • Loading branch information
jeffreymeng authored Sep 26, 2024
1 parent e988e2b commit b92e9f8
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 111 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Also, you need to set some kernel settings with Sysctl. Enable IPv4 forwarding,
```bash
# Applies until next reboot
sudo sysctl -w net.ipv4.ip_forward=1
sudo sysctl -w net.ipv4.conf.all.rp_filter=2
sudo sysctl -w net.ipv4.conf.all.rp_filter=2
```

To set up `vprox`, you'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This has no particular meaning and can be arbitrarily chosen to not overlap with other subnets.
Expand Down
177 changes: 90 additions & 87 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@ import (
"context"
"errors"
"fmt"
"log"
"net/netip"
"os/signal"
"strconv"
"syscall"
"time"

"github.com/coreos/go-iptables/iptables"
"github.com/fatih/color"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"

"github.com/modal-labs/vprox/lib"
)

var AWS_POLL_DURATION = 5000 * time.Millisecond // AWS is polled this frequently for new IPs

var ServerCmd = &cobra.Command{
Use: "server",
Short: "Start a VPN server, listening for new WireGuard peers",
Expand All @@ -45,19 +43,11 @@ func init() {

func runServer(cmd *cobra.Command, args []string) error {
cloud := serverCmdArgs.cloud
if cloud == "aws" {
client := lib.NewAwsMetadata()
interfaces, err := client.GetAddresses()
if err != nil {
return fmt.Errorf("failed to get AWS MAC addresses: %v", err)
}
fmt.Printf("%+v\n", interfaces)
return errors.New("todo: unimplemented")
} else if cloud != "" {
if cloud != "" && cloud != "aws" {
return fmt.Errorf("unknown value of --cloud: %v", cloud)
}

if len(serverCmdArgs.ip) == 0 {
if cloud == "" && len(serverCmdArgs.ip) == 0 {
return errors.New("missing required flag: --ip")
}
if len(serverCmdArgs.ip) > 1024 {
Expand All @@ -73,21 +63,22 @@ func runServer(cmd *cobra.Command, args []string) error {
}
wgBlock = wgBlock.Masked()

wgBlockPerIp := wgBlock.Bits()
wgBlockPerIp := uint(wgBlock.Bits())
if serverCmdArgs.wgBlockPerIp != "" {
if serverCmdArgs.wgBlockPerIp[0] != '/' {
return errors.New("--wg-block-per-ip must start with '/'")
}
wgBlockPerIp, err = strconv.Atoi(serverCmdArgs.wgBlockPerIp[1:])
parsedUint, err := strconv.ParseUint(serverCmdArgs.wgBlockPerIp[1:], 10, 0)
if err != nil {
return fmt.Errorf("failed to parse --wg-block-per-ip: %v", err)
}
wgBlockPerIp = uint(parsedUint)
}

if wgBlockPerIp > 30 || wgBlockPerIp < wgBlock.Bits() {
if wgBlockPerIp > 30 || wgBlockPerIp < uint(wgBlock.Bits()) {
return fmt.Errorf("invalid value of --wg-block-per-ip: %v", wgBlockPerIp)
}
wgBlockCount := 1 << (wgBlockPerIp - wgBlock.Bits())
wgBlockCount := 1 << (wgBlockPerIp - uint(wgBlock.Bits()))
if len(serverCmdArgs.ip) > wgBlockCount {
return fmt.Errorf(
"not enough IPs in --wg-block for %v --ip flags, please set --wg-block-per-ip",
Expand All @@ -104,93 +95,105 @@ func runServer(cmd *cobra.Command, args []string) error {
return err
}

// Make a shared WireGuard client.
wgClient, err := wgctrl.New()
if err != nil {
return fmt.Errorf("failed to initialize wgctrl: %v", err)
}
ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)

ipt, err := iptables.New(iptables.IPFamily(iptables.ProtocolIPv4), iptables.Timeout(5))
sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password)
if err != nil {
return fmt.Errorf("failed to initialize iptables: %v", err)
done()
return err
}

// Display the public key, just for information.
fmt.Printf("%s %s\n",
color.New(color.Bold).Sprint("server public key:"),
key.PublicKey().String())

ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer sm.Wait()
defer done()
g, ctx := errgroup.WithContext(ctx)

wgIp := nextIpBlock(wgBlock.Addr(), 32) // get the ".1" gateway IP address
for i, ipStr := range serverCmdArgs.ip {
ip, err := netip.ParseAddr(ipStr)
if err != nil || !ip.Is4() {
return fmt.Errorf("invalid IPv4 address: %v", ipStr)
if cloud == "aws" {
initialIps, err := pollAws(lib.NewAwsMetadata(), make(ipSet), sm)
if err != nil {
return err
}

srv := &lib.Server{
Key: key,
BindAddr: ip,
Password: password,
Index: uint16(i),
Ipt: ipt,
WgClient: wgClient,
WgCidr: netip.PrefixFrom(wgIp, wgBlockPerIp),
Ctx: ctx,
}
if err := srv.InitState(); err != nil {
return err
pollAwsLoop(ctx, sm, initialIps)
} else {
for _, ipStr := range serverCmdArgs.ip {
ip, err := netip.ParseAddr(ipStr)
if err != nil || !ip.Is4() {
return fmt.Errorf("invalid IPv4 address: %q", ipStr)
}
err = sm.Start(ip)
if err != nil {
return err
}
}
}

// Increment wgIp to be the next block.
wgIp = nextIpBlock(wgIp, uint(wgBlockPerIp))
return nil
}

g.Go(func() error {
if err := srv.StartWireguard(); err != nil {
return fmt.Errorf("failed to start WireGuard: %v", err)
}
defer srv.CleanupWireguard()
type ipSet map[netip.Addr]struct{}

if err := srv.StartIptables(); err != nil {
return fmt.Errorf("failed to start iptables: %v", err)
}
defer srv.CleanupIptables()
// parseIpSet parses the provided ipStrs and creates a map with
// the parsed IPs that can be used as a set.
func parseIpSet(ipStrs []string) (ipSet, error) {
m := make(ipSet)
for _, ipStr := range ipStrs {
ip, err := netip.ParseAddr(ipStr)
if err != nil || !ip.Is4() {
return nil, fmt.Errorf("invalid IPv4 address: %v", ipStr)
}
m[ip] = struct{}{}
}
return m, nil
}

if err := srv.ListenForHttps(); err != nil {
return fmt.Errorf("https server failed: %v", err)
}
return nil
})
// pollAws gets the current set of IP associations from AWS and starts/stops the
// server for those IPs.
func pollAws(awsClient *lib.AwsMetadata, currentIps ipSet, sm *lib.ServerManager) (ipSet, error) {
interfaces, err := awsClient.GetAddresses()

if err != nil {
return currentIps, fmt.Errorf("failed to get AWS MAC addresses: %v", err)
}

if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) {
return err
newIps, err := parseIpSet(interfaces[0].PrivateIps)
if err != nil {
return currentIps, err
}

return nil
}
for ip := range currentIps {
if _, ok := newIps[ip]; !ok {
sm.Stop(ip)
delete(currentIps, ip)
}
}

// Increments the given IP address by the given CIDR block size.
func nextIpBlock(ip netip.Addr, size uint) netip.Addr {
// Copy the IP address to avoid modifying the original.
ipBytes := ip.As4()

bits := 8 * uint(len(ipBytes))
if size > bits {
log.Panicf("nextIpBlock block size of %v is larger than ip bits %v", size, bits)
}
for size > 0 {
byteIndex := (size - 1) / 8
bitIndex := 7 - (size-1)%8
ipBytes[byteIndex] ^= 1 << bitIndex
if ipBytes[byteIndex]&(1<<bitIndex) > 0 {
break
for ip := range newIps {
if _, ok := currentIps[ip]; !ok {
if err := sm.Start(ip); err != nil {
return currentIps, fmt.Errorf("error starting new ip: %v", err)
}
currentIps[ip] = struct{}{}
}
size -= 1
}
return currentIps, nil
}

return netip.AddrFrom4(ipBytes)
// pollAwsLoop polls AWS in a blocking loop on an interval of AWS_POLL_DURATION
// until ctx is done.
func pollAwsLoop(ctx context.Context, sm *lib.ServerManager, initialIps ipSet) {
currentIps := initialIps
awsClient := lib.NewAwsMetadata()
ticker := time.NewTicker(AWS_POLL_DURATION)

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
var err error
currentIps, err = pollAws(awsClient, currentIps, sm)
if err != nil {
fmt.Printf("error during aws poll: %v", err)
}
}
}
}
20 changes: 0 additions & 20 deletions cmd/server_test.go

This file was deleted.

2 changes: 1 addition & 1 deletion lib/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (am *AwsMetadata) GetAddresses() ([]AwsInterface, error) {
return nil, err
}
publicIps := strings.Split(result, "\n")
privateIps := make([]string, len(publicIps))
privateIps := make([]string, 0, len(publicIps))

for _, ip := range publicIps {
privateIp, err := am.get(fmt.Sprintf("%s/%s/ipv4-associations/%s", prefix, mac, ip))
Expand Down
30 changes: 30 additions & 0 deletions lib/iputils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lib

import (
"fmt"
"log"
"net"
"net/netip"
"sync"
Expand Down Expand Up @@ -101,3 +102,32 @@ func (ipa *IpAllocator) Free(addr netip.Addr) bool {
}
return false
}

// AfterCountIpBlock returns the result of incrementing an IP address by N CIDR
// counts.
func AfterCountIpBlock(ip netip.Addr, size uint, count uint) netip.Addr {
// Copy the IP address to avoid modifying the original.
ipBytes := ip.As4()

bits := 8 * uint(len(ipBytes))
if size > bits {
log.Panicf("block size of %v is larger than ip bits %v", size, bits)
}

// CIDR block size rounded up to the nearest multiple of 8
// 32->32, 31->32, 30->32, ..., 25->32, 24->24
tSize := 8 * ((size + 7) / 8)
tCount := count << (tSize - size)
for ; tSize > 0; tSize -= 8 {
c := tCount & 0xff // how much to add to the current byte
tCount = tCount >> 8
// 1.2.3.4/32 (byteIndex = 3) -> 1.2.3.5/32
// 1.2.3.4/24 (byteIndex = 2) -> 1.2.4.4/24
byteIndex := tSize/8 - 1
addWithCarry := uint(ipBytes[byteIndex]) + c
ipBytes[byteIndex] = byte(addWithCarry)
tCount += addWithCarry >> 8
}

return netip.AddrFrom4(ipBytes)
}
32 changes: 32 additions & 0 deletions lib/iputils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package lib

import (
"net/netip"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAfterOneIpBlock(t *testing.T) {
ip1 := netip.AddrFrom4([4]byte{192, 168, 1, 0})
ip2 := netip.AddrFrom4([4]byte{192, 168, 2, 0})
assert.Equal(t, AfterCountIpBlock(ip1, 24, 1), ip2, "next ip mismatch")

ip2 = netip.AddrFrom4([4]byte{192, 168, 1, 16})
assert.Equal(t, AfterCountIpBlock(ip1, 28, 1), ip2, "next ip mismatch")

ip2 = netip.AddrFrom4([4]byte{193, 168, 1, 0})
assert.Equal(t, AfterCountIpBlock(ip1, 8, 1), ip2, "next ip mismatch")
}

func TestAfterCountIpBlock(t *testing.T) {
ip1 := netip.AddrFrom4([4]byte{192, 168, 1, 0})
ip2 := netip.AddrFrom4([4]byte{192, 168, 6, 0})
assert.Equal(t, AfterCountIpBlock(ip1, 24, 5), ip2)

ip2 = netip.AddrFrom4([4]byte{192, 168, 1, 64})
assert.Equal(t, AfterCountIpBlock(ip1, 28, 4), ip2)

ip2 = netip.AddrFrom4([4]byte{192, 168, 2, 128})
assert.Equal(t, AfterCountIpBlock(ip1, 25, 3), ip2)
}
9 changes: 7 additions & 2 deletions lib/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,12 @@ func (srv *Server) StartIptables() error {
}

func (srv *Server) CleanupIptables() {
srv.iptablesInputFwmarkRule(false)
srv.iptablesSnatRule(false)
if err := srv.iptablesInputFwmarkRule(false); err != nil {
log.Printf("warning: error cleaning up IP tables: failed to add fwmark rule: %v\n", err)
}
if err := srv.iptablesSnatRule(false); err != nil {
log.Printf("warning: error cleaning up IP tables: failed to add SNAT rule: %v\n", err)
}
}

func (srv *Server) removeIdlePeersLoop() {
Expand Down Expand Up @@ -446,6 +450,7 @@ func (srv *Server) ListenForHttps() error {

select {
case <-srv.Ctx.Done():
log.Printf("server no longer listening on %v:443\n", srv.BindAddr)
return httpServer.Shutdown(srv.Ctx)
case err = <-errCh:
return err
Expand Down
Loading

0 comments on commit b92e9f8

Please sign in to comment.