Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Aug 21, 2024
1 parent 03c69ab commit 37ccae4
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 46 deletions.
2 changes: 1 addition & 1 deletion internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ func (s *Server) prepareInternalDNS() (err error) {
}

ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset")
s.ipset, err = newIpsetHandler(ipsetLogger, ipsetList)
s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
Expand Down
54 changes: 30 additions & 24 deletions internal/dnsforward/ipset.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ type ipsetHandler struct {

// newIpsetHandler returns a new initialized [ipsetHandler]. It is not safe for
// concurrent use. c is always non-nil for [Server.Close].
func newIpsetHandler(logger *slog.Logger, ipsetList []string) (c *ipsetHandler, err error) {
c = &ipsetHandler{
func newIpsetHandler(
ctx context.Context,
logger *slog.Logger,
ipsetList []string,
) (h *ipsetHandler, err error) {
h = &ipsetHandler{
logger: logger,
}
c.ipsetMgr, err = ipset.NewManager(&ipset.Config{
Logger: logger,
IpsetList: ipsetList,
})
conf := &ipset.Config{
Logger: logger,
Lines: ipsetList,
}
h.ipsetMgr, err = ipset.NewManager(ctx, conf)
if errors.Is(err, os.ErrInvalid) ||
errors.Is(err, os.ErrPermission) ||
errors.Is(err, errors.ErrUnsupported) {
Expand All @@ -41,26 +46,27 @@ func newIpsetHandler(logger *slog.Logger, ipsetList []string) (c *ipsetHandler,
//
// TODO(a.garipov): The Snap problem can probably be solved if we add
// the netlink-connector interface plug.
logger.Warn("cannot initialize", slogutil.KeyError, err)
logger.WarnContext(ctx, "cannot initialize", slogutil.KeyError, err)

return c, nil
return h, nil
} else if err != nil {
return c, fmt.Errorf("initializing ipset: %w", err)
return h, fmt.Errorf("initializing ipset: %w", err)
}

return c, nil
return h, nil
}

// close closes the Linux Netfilter connections.
func (c *ipsetHandler) close() (err error) {
if c.ipsetMgr != nil {
return c.ipsetMgr.Close()
func (h *ipsetHandler) close() (err error) {
if h.ipsetMgr != nil {
return h.ipsetMgr.Close()
}

return nil
}

func (c *ipsetHandler) dctxIsfilled(dctx *dnsContext) (ok bool) {
// dctxIsFilled returns true if dctx has enough information to process.
func dctxIsFilled(dctx *dnsContext) (ok bool) {
return dctx != nil &&
dctx.responseFromUpstream &&
dctx.proxyCtx != nil &&
Expand All @@ -71,8 +77,8 @@ func (c *ipsetHandler) dctxIsfilled(dctx *dnsContext) (ok bool) {

// skipIpsetProcessing returns true when the ipset processing can be skipped for
// this request.
func (c *ipsetHandler) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
if c == nil || c.ipsetMgr == nil || !c.dctxIsfilled(dctx) {
func (h *ipsetHandler) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
if h == nil || h.ipsetMgr == nil || !dctxIsFilled(dctx) {
return true
}

Expand Down Expand Up @@ -114,13 +120,13 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
}

// process adds the resolved IP addresses to the domain's ipsets, if any.
func (c *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
c.logger.Debug("started processing")
defer c.logger.Debug("finished processing")

func (h *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Use passed context.
ctx := context.TODO()
h.logger.DebugContext(ctx, "started processing")
defer h.logger.DebugContext(ctx, "finished processing")

if c.skipIpsetProcessing(dctx) {
if h.skipIpsetProcessing(dctx) {
return resultCodeSuccess
}

Expand All @@ -130,15 +136,15 @@ func (c *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
host = strings.ToLower(host)

ip4s, ip6s := ipsFromAnswer(dctx.proxyCtx.Res.Answer)
n, err := c.ipsetMgr.Add(ctx, host, ip4s, ip6s)
n, err := h.ipsetMgr.Add(ctx, host, ip4s, ip6s)
if err != nil {
// Consider ipset errors non-critical to the request.
c.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
h.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)

return resultCodeSuccess
}

c.logger.DebugContext(ctx, "added new ipset entries", "num", n)
h.logger.DebugContext(ctx, "added new ipset entries", "num", n)

return resultCodeSuccess
}
14 changes: 7 additions & 7 deletions internal/ipset/ipset.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,23 @@ type Config struct {
// not be nil.
Logger *slog.Logger

// IpsetList is the ipset configuration with the following syntax:
// Lines is the ipset configuration with the following syntax:
//
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
//
// IpsetList must not contain any blank lines or comments.
IpsetList []string
// Lines must not contain any blank lines or comments.
Lines []string
}

// NewManager returns a new ipset manager. IPv4 addresses are added to an ipset
// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist.
//
// If conf.IpsetList is empty, mgr and err are nil. The error's chain contains
// If conf.Lines is empty, mgr and err are nil. The error's chain contains
// [errors.ErrUnsupported] if current OS is not supported.
func NewManager(conf *Config) (mgr Manager, err error) {
if len(conf.IpsetList) == 0 {
func NewManager(ctx context.Context, conf *Config) (mgr Manager, err error) {
if len(conf.Lines) == 0 {
return nil, nil
}

return newManager(conf)
return newManager(ctx, conf)
}
25 changes: 15 additions & 10 deletions internal/ipset/ipset_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ import (
// resolved IP addresses.

// newManager returns a new Linux ipset manager.
func newManager(conf *Config) (set Manager, err error) {
return newManagerWithDialer(conf, defaultDial)
func newManager(ctx context.Context, conf *Config) (set Manager, err error) {
return newManagerWithDialer(ctx, conf, defaultDial)
}

// defaultDial is the default netfilter dialing function.
Expand Down Expand Up @@ -258,7 +258,7 @@ func parseIpsetConfigLine(confStr string) (hosts, ipsetNames []string, err error

// parseIpsetConfig parses the ipset configuration and stores ipsets. It
// returns an error if the configuration can't be used.
func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
func (m *manager) parseIpsetConfig(ctx context.Context, ipsetConf []string) (err error) {
// The family doesn't seem to matter when we use a header query, so query
// only the IPv4 one.
//
Expand All @@ -282,7 +282,7 @@ func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
}

var ipsets []props
ipsets, err = m.ipsets(ipsetNames, currentlyKnown)
ipsets, err = m.ipsets(ctx, ipsetNames, currentlyKnown)
if err != nil {
return fmt.Errorf("getting ipsets from config line at idx %d: %w", i, err)
}
Expand Down Expand Up @@ -332,15 +332,20 @@ func (m *manager) ipsetProps(name string) (p props, err error) {

// ipsets returns ipset properties of currently known ipsets. It also makes an
// additional ipset header data query if needed.
func (m *manager) ipsets(names []string, currentlyKnown map[string]props) (sets []props, err error) {
func (m *manager) ipsets(
ctx context.Context,
names []string,
currentlyKnown map[string]props,
) (sets []props, err error) {
for _, n := range names {
p, ok := currentlyKnown[n]
if !ok {
return nil, fmt.Errorf("unknown ipset %q", n)
}

if p.family != netfilter.ProtoIPv4 && p.family != netfilter.ProtoIPv6 {
m.logger.Debug(
m.logger.DebugContext(
ctx,
"got unexpected ipset family while getting set properties",
"set_name", p.name,
"set_type", p.typeName,
Expand All @@ -362,7 +367,7 @@ func (m *manager) ipsets(names []string, currentlyKnown map[string]props) (sets

// newManagerWithDialer returns a new Linux ipset manager using the provided
// dialer.
func newManagerWithDialer(conf *Config, dial dialer) (mgr Manager, err error) {
func newManagerWithDialer(ctx context.Context, conf *Config, dial dialer) (mgr Manager, err error) {
defer func() { err = errors.Annotate(err, "ipset: %w") }()

m := &manager{
Expand All @@ -383,20 +388,20 @@ func newManagerWithDialer(conf *Config, dial dialer) (mgr Manager, err error) {
if errors.Is(err, unix.EPROTONOSUPPORT) {
// The implementation doesn't support this protocol version. Just
// issue a warning.
m.logger.Warn("dialing netfilter", slogutil.KeyError, err)
m.logger.WarnContext(ctx, "dialing netfilter", slogutil.KeyError, err)

return nil, nil
}

return nil, fmt.Errorf("dialing netfilter: %w", err)
}

err = m.parseIpsetConfig(conf.IpsetList)
err = m.parseIpsetConfig(ctx, conf.Lines)
if err != nil {
return nil, fmt.Errorf("getting ipsets: %w", err)
}

m.logger.Debug("initialized")
m.logger.DebugContext(ctx, "initialized")

return m, nil
}
Expand Down
6 changes: 3 additions & 3 deletions internal/ipset/ipset_linux_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ func TestManager_Add(t *testing.T) {
}

conf := &Config{
Logger: slogutil.NewDiscardLogger(),
IpsetList: ipsetList,
Logger: slogutil.NewDiscardLogger(),
Lines: ipsetList,
}
m, err := newManagerWithDialer(conf, fakeDial)
m, err := newManagerWithDialer(testutil.ContextWithTimeout(t, testTimeout), conf, fakeDial)
require.NoError(t, err)

ip4 := net.IP{1, 2, 3, 4}
Expand Down
4 changes: 3 additions & 1 deletion internal/ipset/ipset_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package ipset

import (
"context"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)

func newManager(_ *Config) (mgr Manager, err error) {
func newManager(_ context.Context, _ *Config) (mgr Manager, err error) {
return nil, aghos.Unsupported("ipset")
}

0 comments on commit 37ccae4

Please sign in to comment.