From 911d341310ff7cc973a285355eacbbd2428bc039 Mon Sep 17 00:00:00 2001 From: denis-tingajkin Date: Wed, 8 Apr 2020 12:04:44 +0700 Subject: [PATCH] improve except option Signed-off-by: denis-tingajkin --- README.md | 1 + domain.go | 119 +++++++++++++++++++++++++++++++++++++ domain_test.go | 157 +++++++++++++++++++++++++++++++++++++++++++++++++ fanout.go | 43 +++++--------- fanout_test.go | 26 ++++++++ setup.go | 23 +++++++- setup_test.go | 6 +- 7 files changed, 344 insertions(+), 31 deletions(-) create mode 100644 domain.go create mode 100644 domain_test.go diff --git a/README.md b/README.md index 53d3378..e74dc5c 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Each incoming DNS query that hits the CoreDNS fanout plugin will be replicated i * `worker-count` is the number of parallel queries per request. By default equals to count of IP list. Use this only for reducing parallel queries per request. * `network` is a specific network protocol. Could be `tcp`, `udp`, `tcp-tls`. * `except` is a list is a space-separated list of domains to exclude from proxying. +* `except-file` is the path to file with line-separated list of domains to exclude from proxying. * `attempt-count` is the number of attempts to connect to upstream servers that are needed before considering an upstream to be down. If 0, the upstream will never be marked as down and request will be finished by `timeout`. Default is `3`. * `timeout` is the timeout of request. After this period, attempts to receive a response from the upstream servers will be stopped. Default is `30s`. ## Metrics diff --git a/domain.go b/domain.go new file mode 100644 index 0000000..673e76a --- /dev/null +++ b/domain.go @@ -0,0 +1,119 @@ +// Copyright (c) 2020 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fanout + +import ( + "strings" +) + +// Domain represents DNS domain name +type Domain interface { + Get(string) Domain + AddString(string) + Add(string, Domain) + Contains(string) bool + IsFinal() bool + Finish() +} + +type domain struct { + children map[string]Domain + end bool +} + +// Finish marks current domain as last in the chain +func (l *domain) Finish() { + l.end = true +} + +// Add adds domain by name +func (l *domain) Add(n string, d Domain) { + l.children[n] = d +} + +// IsFinal returns true if this domain is last in the chain +func (l *domain) IsFinal() bool { + return l.end +} + +// Contains parses string and check is domains contains +func (l *domain) Contains(s string) bool { + end := len(s) + var curr Domain = l + for start := strings.LastIndex(s, "."); start != -1; start = strings.LastIndex(s[:start], ".") { + var k string + if start == end-1 { + k = "." + } else { + k = s[start+1 : end] + } + end = start + curr = curr.Get(k) + if curr == nil { + return false + } + if curr.IsFinal() { + return true + } + } + curr = curr.Get(s[:end]) + if curr == nil { + return false + } + return curr.IsFinal() +} + +// AddString parses string and adds child domains +func (l *domain) AddString(s string) { + end := len(s) + var curr = Domain(l) + for start := strings.LastIndex(s, "."); start != -1; start = strings.LastIndex(s[:start], ".") { + var k string + if start == end-1 { + k = "." + } else { + k = s[start+1 : end] + } + end = start + if v := curr.Get(k); v != nil { + if v.IsFinal() { + return + } + curr = v + } else { + next := &domain{children: map[string]Domain{}} + curr.Add(k, next) + curr = next + } + } + if s != "." { + next := &domain{children: map[string]Domain{}, end: true} + curr.Add(s[:end], next) + } else { + curr.Finish() + } +} + +// Get returns child domain by name +func (l *domain) Get(s string) Domain { + return l.children[s] +} + +// NewDomain creates new domain instance +func NewDomain() Domain { + return &domain{children: map[string]Domain{}} +} diff --git a/domain_test.go b/domain_test.go new file mode 100644 index 0000000..8f70a8a --- /dev/null +++ b/domain_test.go @@ -0,0 +1,157 @@ +// Copyright (c) 2020 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fanout + +import ( + "crypto/rand" + "math/big" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDomainBasic(t *testing.T) { + samples := []struct { + child string + parent string + expected bool + }{ + {".", ".", true}, + {"example.org.", ".", true}, + {"example.org.", "example.org.", true}, + {"example.org", "example.org", true}, + {"example.org.", "org.", true}, + {"org.", "example.org.", false}, + } + + for i, s := range samples { + l := NewDomain() + l.AddString(s.parent) + require.Equal(t, s.expected, l.Contains(s.child), i) + } +} + +func TestDomainGet(t *testing.T) { + d := NewDomain() + d.AddString("google.com.") + d.AddString("example.com.") + require.True(t, d.Get(".").Get("com").Get("google").IsFinal()) +} + +func TestDomain_ContainsShouldWorkFast(t *testing.T) { + var samples []string + d := NewDomain() + for i := 0; i < 100; i++ { + for j := 0; j < 100; j++ { + samples = append(samples, genSample(i+1)) + d.AddString(samples[len(samples)-1]) + } + } + start := time.Now() + for i := 0; i < 10000; i++ { + require.True(t, d.Contains(samples[i])) + } + require.True(t, time.Since(start) < time.Second/5) +} + +func TestDomainFewEntries(t *testing.T) { + d := NewDomain() + d.AddString("google.com.") + d.AddString("example.com.") + require.True(t, d.Contains("google.com.")) + require.True(t, d.Contains("example.com.")) + require.False(t, d.Contains("com.")) +} + +func TestDomain_DoNotStoreExtraEntries(t *testing.T) { + d := NewDomain() + d.AddString("example.com.") + d.AddString("advanced.example.com.") + require.Nil(t, d.Get(".").Get("com").Get("example").Get("advanced")) +} + +func BenchmarkDomain_ContainsMatch(b *testing.B) { + d := NewDomain() + var samples []string + for i := 0; i < 100; i++ { + for j := 0; j < 100; j++ { + samples = append(samples, genSample(i+1)) + d.AddString(samples[len(samples)-1]) + } + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 10000; j++ { + d.Contains(samples[j]) + } + } +} + +func BenchmarkDomain_AddString(b *testing.B) { + d := NewDomain() + var samples []string + for i := 0; i < 100; i++ { + for j := 0; j < 100; j++ { + samples = append(samples, genSample(i+1)) + } + } + for i := 0; i < b.N; i++ { + for j := 0; j < len(samples); j++ { + d.AddString(samples[j]) + } + } +} + +func BenchmarkDomain_ContainsAny(b *testing.B) { + d := NewDomain() + var samples []string + for i := 0; i < 100; i++ { + for j := 0; j < 100; j++ { + d.AddString(genSample(i + 1)) + samples = append(samples, genSample(i+1)) + } + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < len(samples); j++ { + d.Contains(samples[j]) + } + } +} + +func genSample(n int) string { + randInt := func() int { + r, err := rand.Int(rand.Reader, big.NewInt(100)) + if err != nil { + panic(err.Error()) + } + return int(r.Int64()) + } + + var sb strings.Builder + for segment := 0; segment < n; segment++ { + l := randInt()%9 + 1 + for i := 0; i < l; i++ { + v := (randInt() % 26) + 97 + _, _ = sb.WriteRune(rune(v)) + } + _, _ = sb.WriteRune('.') + } + return sb.String() +} diff --git a/fanout.go b/fanout.go index 81029c0..2d50b69 100644 --- a/fanout.go +++ b/fanout.go @@ -33,25 +33,26 @@ var log = clog.NewWithPlugin("fanout") // Fanout represents a plugin instance that can do async requests to list of DNS servers. type Fanout struct { - clients []Client - tlsConfig *tls.Config - ignored []string - tlsServerName string - timeout time.Duration - net string - from string - attempts int - workerCount int - Next plugin.Handler + clients []Client + tlsConfig *tls.Config + excludeDomains Domain + tlsServerName string + timeout time.Duration + net string + from string + attempts int + workerCount int + Next plugin.Handler } // New returns reference to new Fanout plugin instance with default configs. func New() *Fanout { return &Fanout{ - tlsConfig: new(tls.Config), - net: "udp", - attempts: 3, - timeout: defaultTimeout, + tlsConfig: new(tls.Config), + net: "udp", + attempts: 3, + timeout: defaultTimeout, + excludeDomains: NewDomain(), } } @@ -135,24 +136,12 @@ func (f *Fanout) getFanoutResult(ctx context.Context, responseCh <-chan *respons } func (f *Fanout) match(state *request.Request) bool { - if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) { + if !plugin.Name(f.from).Matches(state.Name()) || f.excludeDomains.Contains(state.Name()) { return false } return true } -func (f *Fanout) isAllowedDomain(name string) bool { - if dns.Name(name) == dns.Name(f.from) { - return true - } - for _, ignore := range f.ignored { - if plugin.Name(ignore).Matches(name) { - return false - } - } - return true -} - func (f *Fanout) processClient(ctx context.Context, c Client, r *request.Request) *response { start := time.Now() for j := 0; j < f.attempts || f.attempts == 0; <-time.After(attemptDelay) { diff --git a/fanout_test.go b/fanout_test.go index 47325ec..7fab3ff 100644 --- a/fanout_test.go +++ b/fanout_test.go @@ -19,12 +19,17 @@ package fanout import ( "context" "fmt" + "io/ioutil" "net" + "os" + "strings" "sync" "sync/atomic" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/caddyserver/caddy" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/stretchr/testify/suite" @@ -105,6 +110,26 @@ type fanoutTestSuite struct { network string } +func TestFanout_ExceptFile(t *testing.T) { + file, err := ioutil.TempFile(os.TempDir(), t.Name()) + exclude := []string{"example1.com.", "example2.com."} + require.Nil(t, err) + defer func() { + require.Nil(t, os.Remove(file.Name())) + }() + _, err = file.WriteString(strings.Join(exclude, "\n")) + require.Nil(t, err) + source := fmt.Sprintf(`fanout . 0.0.0.0:53 { + except-file %v +}`, file.Name()) + c := caddy.NewTestController("dns", source) + f, err := parseFanout(c) + require.Nil(t, err) + for _, e := range exclude { + require.True(t, f.excludeDomains.Contains(e)) + } +} + func (t *fanoutTestSuite) TestConfigFromCorefile() { s := newServer(t.network, func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) @@ -169,6 +194,7 @@ func (t *fanoutTestSuite) TestWorkerCountLessThenServers() { f.addClient(NewClient(correctServer.addr, t.network)) f.workerCount = 1 + f.attempts = 1 req := new(dns.Msg) req.SetQuestion(testQuery, dns.TypeA) _, err := f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) diff --git a/setup.go b/setup.go index f7d49d8..c912f63 100644 --- a/setup.go +++ b/setup.go @@ -17,6 +17,8 @@ package fanout import ( + "io/ioutil" + "path/filepath" "strconv" "strings" "time" @@ -157,6 +159,8 @@ func parseValue(v string, f *Fanout, c *caddyfile.Dispenser) error { return parseTimeout(f, c) case "except": return parseIgnored(f, c) + case "except-file": + return parseIgnoredFromFile(f, c) case "attempt-count": num, err := parsePositiveInt(c) f.attempts = num @@ -176,15 +180,30 @@ func parseTimeout(f *Fanout, c *caddyfile.Dispenser) error { return err } +func parseIgnoredFromFile(f *Fanout, c *caddyfile.Dispenser) error { + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + b, err := ioutil.ReadFile(filepath.Clean(args[0])) + if err != nil { + return err + } + names := strings.Split(string(b), "\n") + for i := 0; i < len(names); i++ { + f.excludeDomains.AddString(plugin.Host(names[i]).Normalize()) + } + return nil +} + func parseIgnored(f *Fanout, c *caddyfile.Dispenser) error { ignore := c.RemainingArgs() if len(ignore) == 0 { return c.ArgErr() } for i := 0; i < len(ignore); i++ { - ignore[i] = plugin.Host(ignore[i]).Normalize() + f.excludeDomains.AddString(plugin.Host(ignore[i]).Normalize()) } - f.ignored = ignore return nil } diff --git a/setup_test.go b/setup_test.go index 87dc0f5..1cc7331 100644 --- a/setup_test.go +++ b/setup_test.go @@ -76,8 +76,10 @@ func TestSetup(t *testing.T) { t.Fatalf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from) } if test.expectedIgnored != nil { - if !reflect.DeepEqual(f.ignored, test.expectedIgnored) { - t.Fatalf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored) + for _, expected := range test.expectedIgnored { + if !f.excludeDomains.Contains(expected) { + t.Fatalf("Test %d: missed exclude domain name: %v", i, test.expectedIgnored) + } } } if test.expectedTo != nil {