diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 11619bbeb..9f2987bc3 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -79,6 +79,7 @@ type BlockingResolver struct { blockHandler blockHandler whitelistOnlyGroups map[string]bool status *status + clientGroupsBlock map[string][]string } // NewBlockingResolver returns a new configured instance of the resolver @@ -103,6 +104,19 @@ func NewBlockingResolver(cfg config.BlockingConfig) (ChainedResolver, error) { return nil, multierror.Prefix(err, "blocking resolver: ") } + cgb := make(map[string][]string) + + for identifier, cfgGroups := range cfg.ClientGroupsBlock { + for _, ipart := range strings.Split(identifier, ",") { + existingGroups, found := cgb[ipart] + if found { + cgb[ipart] = append(existingGroups, cfgGroups...) + } else { + cgb[ipart] = cfgGroups + } + } + } + res := &BlockingResolver{ blockHandler: blockHandler, cfg: cfg, @@ -113,6 +127,7 @@ func NewBlockingResolver(cfg config.BlockingConfig) (ChainedResolver, error) { enabled: true, enableTimer: time.NewTimer(0), }, + clientGroupsBlock: cgb, } return res, nil @@ -365,7 +380,7 @@ func (r *BlockingResolver) groupsToCheckForClient(request *model.Request) []stri var groups []string // try client names for _, cName := range request.ClientNames { - for blockGroup, groupsByName := range r.cfg.ClientGroupsBlock { + for blockGroup, groupsByName := range r.clientGroupsBlock { if util.ClientNameMatchesGroupName(blockGroup, cName) { groups = append(groups, groupsByName...) } @@ -373,14 +388,14 @@ func (r *BlockingResolver) groupsToCheckForClient(request *model.Request) []stri } // try IP - groupsByIP, found := r.cfg.ClientGroupsBlock[request.ClientIP.String()] + groupsByIP, found := r.clientGroupsBlock[request.ClientIP.String()] if found { groups = append(groups, groupsByIP...) } // try CIDR - for cidr, groupsByCidr := range r.cfg.ClientGroupsBlock { + for cidr, groupsByCidr := range r.clientGroupsBlock { if util.CidrContainsIP(cidr, request.ClientIP) { groups = append(groups, groupsByCidr...) } @@ -388,7 +403,7 @@ func (r *BlockingResolver) groupsToCheckForClient(request *model.Request) []stri if len(groups) == 0 { // return default - groups = r.cfg.ClientGroupsBlock["default"] + groups = r.clientGroupsBlock["default"] } var result []string diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index e66ad7840..459a01a48 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -118,12 +118,14 @@ badcnamedomain.com`) "defaultGroup": {defaultGroupFile.Name()}, }, ClientGroupsBlock: map[string][]string{ - "client1": {"gr1"}, - "192.168.178.55": {"gr1"}, - "altName": {"gr2"}, - "10.43.8.67/28": {"gr1"}, - "wildcard[0-9]*": {"gr1"}, - "default": {"defaultGroup"}, + "client1": {"gr1"}, + "client2,client3": {"gr1"}, + "client3": {"gr2"}, + "192.168.178.55": {"gr1"}, + "altName": {"gr2"}, + "10.43.8.67/28": {"gr1"}, + "wildcard[0-9]*": {"gr1"}, + "default": {"defaultGroup"}, }, BlockType: "ZeroIP", } @@ -134,11 +136,26 @@ badcnamedomain.com`) }) When("client name is defined in client groups block", func() { - It("should block the A query if domain is on the black list", func() { + It("should block the A query if domain is on the black list (single)", func() { resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "client1")) Expect(resp.Res.Answer).Should(BeDNSRecord("domain1.com.", dns.TypeA, 21600, "0.0.0.0")) }) + It("should block the A query if domain is on the black list (multipart 1)", func() { + resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "client2")) + + Expect(resp.Res.Answer).Should(BeDNSRecord("domain1.com.", dns.TypeA, 21600, "0.0.0.0")) + }) + It("should block the A query if domain is on the black list (multipart 2)", func() { + resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "client3")) + + Expect(resp.Res.Answer).Should(BeDNSRecord("domain1.com.", dns.TypeA, 21600, "0.0.0.0")) + }) + It("should block the A query if domain is on the black list (merged)", func() { + resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "1.2.1.2", "client3")) + + Expect(resp.Res.Answer).Should(BeDNSRecord("blocked2.com.", dns.TypeA, 21600, "0.0.0.0")) + }) It("should block the AAAA query if domain is on the black list", func() { resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeAAAA, "1.2.1.2", "client1"))