Skip to content

Commit

Permalink
dhcpsvc: add constructor, validations, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Sep 13, 2023
1 parent f2533ed commit c9ef290
Show file tree
Hide file tree
Showing 6 changed files with 768 additions and 66 deletions.
143 changes: 79 additions & 64 deletions internal/dhcpsvc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ import (
"net/netip"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket/layers"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

// InterfaceConfig is the configuration of a single DHCP interface.
type InterfaceConfig struct {
// IPv4 is the configuration of DHCP protocol for IPv4.
IPv4 *IPv4Config

// IPv6 is the configuration of DHCP protocol for IPv6.
IPv6 *IPv6Config
}

// Config is the configuration for the DHCP service.
type Config struct {
// Interfaces stores configurations of DHCP server specific for the network
Expand All @@ -29,13 +37,46 @@ type Config struct {
Enabled bool
}

// InterfaceConfig is the configuration of a single DHCP interface.
type InterfaceConfig struct {
// IPv4 is the configuration of DHCP protocol for IPv4.
IPv4 *IPv4Config
// Validate returns an error in conf if any.
func (conf *Config) Validate() (err error) {
switch {
case conf == nil:
return errNilConfig
case !conf.Enabled:
return nil
case conf.ICMPTimeout < 0:
return fmt.Errorf("icmp timeout %s must be non-negative", conf.ICMPTimeout)
}

// IPv6 is the configuration of DHCP protocol for IPv6.
IPv6 *IPv6Config
err = netutil.ValidateDomainName(conf.LocalDomainName)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

if len(conf.Interfaces) == 0 {
return errNoInterfaces
}

ifaces := maps.Keys(conf.Interfaces)
slices.Sort(ifaces)

for _, iface := range ifaces {
ifaceConf := conf.Interfaces[iface]
if ifaceConf == nil {
return fmt.Errorf("interface %q: %w", iface, errNilConfig)
}

if err = ifaceConf.IPv4.validate(); err != nil {
return fmt.Errorf("interface %q: ipv4: %w", iface, err)
}

if err = ifaceConf.IPv6.validate(); err != nil {
return fmt.Errorf("interface %q: ipv6: %w", iface, err)
}
}

return nil
}

// IPv4Config is the interface-specific configuration for DHCPv4.
Expand Down Expand Up @@ -66,6 +107,28 @@ type IPv4Config struct {
Enabled bool
}

// validate returns an error in conf if any.
func (conf *IPv4Config) validate() (err error) {
switch {
case conf == nil:
return errNilConfig
case !conf.Enabled:
return nil
case !conf.GatewayIP.Is4():
return fmt.Errorf("gateway ip %s should be a valid ipv4", conf.GatewayIP)
case !conf.SubnetMask.Is4():
return fmt.Errorf("subnet mask %s should be a valid ipv4 cidr", conf.SubnetMask)
case !conf.RangeStart.Is4():
return fmt.Errorf("range start %s should be a valid ipv4", conf.RangeStart)
case !conf.RangeEnd.Is4():
return fmt.Errorf("range end %s should be a valid ipv4", conf.RangeEnd)
case conf.LeaseDuration <= 0:
return fmt.Errorf("lease duration %s must be positive", conf.LeaseDuration)
default:
return nil
}
}

// IPv6Config is the interface-specific configuration for DHCPv6.
type IPv6Config struct {
// RangeStart is the first address in the range to assign to DHCP clients.
Expand All @@ -90,66 +153,18 @@ type IPv6Config struct {
Enabled bool
}

// TODO(e.burkov): !! doc
const ErrNilConfig errors.Error = "config is nil"

func (conf *Config) Validate() (err error) {
// validate returns an error in conf if any.
func (conf *IPv6Config) validate() (err error) {
switch {
case conf == nil:
return ErrNilConfig
return errNilConfig
case !conf.Enabled:
return nil
case conf.ICMPTimeout < 0:
return fmt.Errorf("icmp timeout %s must be non-negative", conf.ICMPTimeout)
}

err = netutil.ValidateDomainName(conf.LocalDomainName)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

ifaces := maps.Keys(conf.Interfaces)
slices.Sort(ifaces)

return errors.Join(
errors.Annotate(conf.validateV4(ifaces), "validating v4: %w"),
errors.Annotate(conf.validateV6(ifaces), "validating v6: %w"),
)
}

func (conf *Config) validateV4(ifaces []string) (err error) {
for _, iface := range ifaces {
ifaceConf := conf.Interfaces[iface]
if ifaceConf == nil {
return ErrNilConfig
}

v4Conf := ifaceConf.IPv4
switch {
case !v4Conf.Enabled:
continue
case !v4Conf.GatewayIP.Is4():
return fmt.Errorf("interface %q: gateway ip should be a valid ipv4", iface)
case !v4Conf.SubnetMask.Is4():
return fmt.Errorf("interface %q: subnet mask should be a valid ipv4 cidr", iface)
case !v4Conf.RangeStart.Is4():
return fmt.Errorf("interface %q: range start should be a valid ipv4", iface)
case !v4Conf.RangeEnd.Is4():
return fmt.Errorf("interface %q: range end should be a valid ipv4", iface)
}

c.ipRange, err = newIPRange(rangeStart.AsSlice(), rangeEnd.AsSlice())
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
case !conf.RangeStart.Is6():
return fmt.Errorf("range start %s should be a valid ipv6", conf.RangeStart)
case conf.LeaseDuration <= 0:
return fmt.Errorf("lease duration %s must be positive", conf.LeaseDuration)
default:
return nil
}

return nil
}

func (conf *Config) validateV6(ifaces []string) (err error) {
return nil
}
11 changes: 11 additions & 0 deletions internal/dhcpsvc/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dhcpsvc

import "github.com/AdguardTeam/golibs/errors"

const (
// errNilConfig is returned when a nil config met.
errNilConfig errors.Error = "config is nil"

// errNoInterfaces is returned when no interfaces found in configuration.
errNoInterfaces errors.Error = "no interfaces specified"
)
124 changes: 124 additions & 0 deletions internal/dhcpsvc/iprange.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package dhcpsvc

import (
"fmt"
"math"
"math/big"
"net"
"net/netip"

"github.com/AdguardTeam/golibs/errors"
)

// ipRange is an inclusive range of IP addresses. A nil range is a range that
// doesn't contain any IP addresses.
//
// It is safe for concurrent use.
//
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
// ranges? Or use one of uint128 packages?
type ipRange struct {
start *big.Int
end *big.Int
}

// maxRangeLen is the maximum IP range length. The bitsets used in servers only
// accept uints, which can have the size of 32 bit.
const maxRangeLen = math.MaxUint32

// newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen.
func newIPRange(start, end netip.Addr) (r *ipRange, err error) {
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()

if !start.Less(end) {
return nil, fmt.Errorf("start is greater than or equal to end")
}

// Make sure that both are 16 bytes long to simplify handling in
// methods.
startData, endData := start.As16(), end.As16()

startInt := (&big.Int{}).SetBytes(startData[:])
endInt := (&big.Int{}).SetBytes(endData[:])
diff := (&big.Int{}).Sub(endInt, startInt)

if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
return nil, fmt.Errorf("range is too large")
}

return &ipRange{
start: startInt,
end: endInt,
}, nil
}

// contains returns true if r contains ip.
func (r *ipRange) contains(ip netip.Addr) (ok bool) {
if r == nil {
return false
}

ipData := ip.As16()

return r.containsInt((&big.Int{}).SetBytes(ipData[:]))
}

// containsInt returns true if r contains ipInt. For internal use only.
func (r *ipRange) containsInt(ipInt *big.Int) (ok bool) {
return ipInt.Cmp(r.start) >= 0 && ipInt.Cmp(r.end) <= 0
}

// ipPredicate is a function that is called on every IP address in
// (*ipRange).find. ip is given in the 16-byte form.
type ipPredicate func(ip netip.Addr) (ok bool)

// find finds the first IP address in r for which p returns true. ip is in the
// 16-byte form. It returns an empty [netip.Addr] if no addresses satisfy p.
func (r *ipRange) find(p ipPredicate) (ip netip.Addr) {
if r == nil {
return netip.Addr{}
}

_1 := big.NewInt(1)
var ipData [16]byte
for i := (&big.Int{}).Set(r.start); i.Cmp(r.end) <= 0; i.Add(i, _1) {
i.FillBytes(ipData[:])
ip = netip.AddrFrom16(ipData)
if p(ip) {
return ip
}
}

return netip.Addr{}
}

// offset returns the offset of ip from the beginning of r. It returns 0 and
// false if ip is not in r.
func (r *ipRange) offset(ip netip.Addr) (offset uint64, ok bool) {
if r == nil {
return 0, false
}

ipData := ip.As16()
ipInt := (&big.Int{}).SetBytes(ipData[:])
if !r.containsInt(ipInt) {
return 0, false
}

offsetInt := (&big.Int{}).Sub(ipInt, r.start)

// Assume that the range was checked against maxRangeLen during
// construction.
return offsetInt.Uint64(), true
}

// String implements the fmt.Stringer interface for *ipRange.
func (r *ipRange) String() (s string) {
start, end := [16]byte{}, [16]byte{}

r.start.FillBytes(start[:])
r.end.FillBytes(end[:])

return fmt.Sprintf("%s-%s", net.IP(start[:]), net.IP(end[:]))
}
Loading

0 comments on commit c9ef290

Please sign in to comment.