Skip to content

Commit

Permalink
#202: WhitelistOnly Fix for multiple entries (#199)
Browse files Browse the repository at this point in the history
* Update blocking_resolver.go

Adjusted WhitelistOnly

* added test

* fixed golint issues

Co-authored-by: Dimitri Herzog <dimitri.herzog@gmail.com>
  • Loading branch information
c-f and 0xERR0R authored May 5, 2021
1 parent 1d511a3 commit dd69a3e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
23 changes: 16 additions & 7 deletions resolver/blocking_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"blocky/util"
"fmt"
"net"
"reflect"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -65,7 +64,7 @@ type BlockingResolver struct {
whitelistMatcher *lists.ListCache
cfg config.BlockingConfig
blockHandler blockHandler
whitelistOnlyGroups []string
whitelistOnlyGroups map[string]bool
status *status
}

Expand Down Expand Up @@ -177,17 +176,17 @@ func (r *BlockingResolver) BlockingStatus() api.BlockingStatus {
}

// returns groups, which have only whitelist entries
func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result []string) {
func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result map[string]bool) {
result = make(map[string]bool)

for g, links := range cfg.WhiteLists {
if len(links) > 0 {
if _, found := cfg.BlackLists[g]; !found {
result = append(result, g)
result[g] = true
}
}
}

sort.Strings(result)

return
}

Expand Down Expand Up @@ -230,10 +229,20 @@ func (r *BlockingResolver) Configuration() (result []string) {
return
}

func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool {
for _, group := range groupsToCheck {
if _, found := r.whitelistOnlyGroups[group]; found {
return true
}
}

return false
}

func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
request *Request, logger *logrus.Entry) (*Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
whitelistOnlyAllowed := reflect.DeepEqual(groupsToCheck, r.whitelistOnlyGroups)
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)

for _, question := range request.Req.Question {
domain := util.ExtractDomain(question)
Expand Down
41 changes: 38 additions & 3 deletions resolver/blocking_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,19 @@ badcnamedomain.com`)
When("Only whitelist is defined", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
WhiteLists: map[string][]string{"gr1": {group1File.Name()}},
WhiteLists: map[string][]string{
"gr1": {group1File.Name()},
"gr2": {group2File.Name()},
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
"default": {"gr1"},
"one-client": {"gr1"},
"two-client": {"gr2"},
"all-client": {"gr1", "gr2"},
},
}
})
It("should block everything else except domains on the white list", func() {
It("should block everything else except domains on the white list with default group", func() {
By("querying domain on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "unknown"))

Expand All @@ -379,6 +385,35 @@ badcnamedomain.com`)
Expect(resp.Reason).Should(Equal("BLOCKED (WHITELIST ONLY)"))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "one-client"))

// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})

By("querying another domain, which is not on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "1.2.1.2", "one-client"))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Reason).Should(Equal("BLOCKED (WHITELIST ONLY)"))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist group 1", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "all-client"))

// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})

By("querying another domain, which is in the whitelist group 1", func() {
resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "1.2.1.2", "all-client"))
Expect(m.Calls).Should(HaveLen(2))
})
})
})

When("IP address is on black and white list", func() {
Expand Down

0 comments on commit dd69a3e

Please sign in to comment.