Skip to content

Commit

Permalink
all: imp err handling
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jun 20, 2023
1 parent 8543868 commit 29cfc7a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 26 deletions.
15 changes: 11 additions & 4 deletions internal/filtering/blocked.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package filtering

import (
"encoding/json"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -63,11 +64,17 @@ func (s *BlockedServices) Clone() (c *BlockedServices) {
}
}

// BlockedSvcKnown returns true if a blocked service ID is known.
func BlockedSvcKnown(s string) (ok bool) {
_, ok = serviceRules[s]
// Validate returns an error if blocked services contain unknown service ID. s
// must not be nil.
func (s *BlockedServices) Validate() (err error) {
for _, id := range s.IDs {
_, ok := serviceRules[id]
if !ok {
return fmt.Errorf("unknown blocked-service %q", id)
}
}

return ok
return nil
}

// ApplyBlockedServices - set blocked services settings for this DNS request
Expand Down
12 changes: 3 additions & 9 deletions internal/filtering/filtering.go
Original file line number Diff line number Diff line change
Expand Up @@ -988,17 +988,11 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
}

if d.BlockedServices != nil {
bsvcs := []string{}
for _, s := range d.BlockedServices.IDs {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service %q", s)
err = d.BlockedServices.Validate()

continue
}

bsvcs = append(bsvcs, s)
if err != nil {
return nil, fmt.Errorf("filtering: %w", err)
}
d.BlockedServices.IDs = bsvcs
}

if blockFilters != nil {
Expand Down
28 changes: 22 additions & 6 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
filteringConf *filtering.Config,
) {
) (err error) {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
Expand All @@ -91,13 +91,17 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf)
err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}

clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)

if clients.testing {
return
return nil
}

clients.updateFromDHCP(true)
Expand All @@ -108,6 +112,8 @@ func (clients *clientsContainer) Init(
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}

return nil
}

func (clients *clientsContainer) handleHostsUpdates() {
Expand Down Expand Up @@ -168,7 +174,10 @@ type clientObject struct {

// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for _, o := range objects {
cli := &Client{
Name: o.Name,
Expand All @@ -189,7 +198,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}

err := cli.setSafeSearch(
err = cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
Expand All @@ -201,6 +210,11 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
}
}

err = o.BlockedServices.Validate()
if err != nil {
return fmt.Errorf("clients: %w", err)
}

cli.BlockedServices = o.BlockedServices.Clone()

for _, t := range o.Tags {
Expand All @@ -213,11 +227,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin

slices.Sort(cli.Tags)

_, err := clients.Add(cli)
_, err = clients.Add(cli)
if err != nil {
log.Error("clients: adding clients %s: %s", cli.Name, err)
}
}

return nil
}

// forConfig returns all currently known persistent clients as objects for the
Expand Down
13 changes: 7 additions & 6 deletions internal/home/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ import (

// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer() (c *clientsContainer) {
func newClientsContainer(t *testing.T) (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}

c.Init(nil, nil, nil, nil, &filtering.Config{})
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
require.NoError(t, err)

return c
}

func TestClients(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)

t.Run("add_success", func(t *testing.T) {
var (
Expand Down Expand Up @@ -198,7 +199,7 @@ func TestClients(t *testing.T) {
}

func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
whois := &RuntimeClientWHOISInfo{
Country: "AU",
Orgname: "Example Org",
Expand Down Expand Up @@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
}

func TestClientsAddExisting(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)

t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
Expand Down Expand Up @@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
}

func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)

// Add client with upstreams.
ok, err := clients.Add(&Client{
Expand Down
6 changes: 5 additions & 1 deletion internal/home/home.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,17 @@ func initContextClients() (err error) {
arpdb = aghnet.NewARPDB()
}

Context.clients.Init(
err = Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}

return nil
}
Expand Down

0 comments on commit 29cfc7a

Please sign in to comment.