Skip to content

Commit

Permalink
Merge pull request #17 from denis-tingajkin/improve_except
Browse files Browse the repository at this point in the history
Fanout plugin: Improve except option
  • Loading branch information
haiodo authored Apr 9, 2020
2 parents 47a5b58 + 911d341 commit 89e1e9d
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 31 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Each incoming DNS query that hits the CoreDNS fanout plugin will be replicated i
* `worker-count` is the number of parallel queries per request. By default equals to count of IP list. Use this only for reducing parallel queries per request.
* `network` is a specific network protocol. Could be `tcp`, `udp`, `tcp-tls`.
* `except` is a list is a space-separated list of domains to exclude from proxying.
* `except-file` is the path to file with line-separated list of domains to exclude from proxying.
* `attempt-count` is the number of attempts to connect to upstream servers that are needed before considering an upstream to be down. If 0, the upstream will never be marked as down and request will be finished by `timeout`. Default is `3`.
* `timeout` is the timeout of request. After this period, attempts to receive a response from the upstream servers will be stopped. Default is `30s`.
## Metrics
Expand Down
119 changes: 119 additions & 0 deletions domain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package fanout

import (
"strings"
)

// Domain represents DNS domain name
type Domain interface {
Get(string) Domain
AddString(string)
Add(string, Domain)
Contains(string) bool
IsFinal() bool
Finish()
}

type domain struct {
children map[string]Domain
end bool
}

// Finish marks current domain as last in the chain
func (l *domain) Finish() {
l.end = true
}

// Add adds domain by name
func (l *domain) Add(n string, d Domain) {
l.children[n] = d
}

// IsFinal returns true if this domain is last in the chain
func (l *domain) IsFinal() bool {
return l.end
}

// Contains parses string and check is domains contains
func (l *domain) Contains(s string) bool {
end := len(s)
var curr Domain = l
for start := strings.LastIndex(s, "."); start != -1; start = strings.LastIndex(s[:start], ".") {
var k string
if start == end-1 {
k = "."
} else {
k = s[start+1 : end]
}
end = start
curr = curr.Get(k)
if curr == nil {
return false
}
if curr.IsFinal() {
return true
}
}
curr = curr.Get(s[:end])
if curr == nil {
return false
}
return curr.IsFinal()
}

// AddString parses string and adds child domains
func (l *domain) AddString(s string) {
end := len(s)
var curr = Domain(l)
for start := strings.LastIndex(s, "."); start != -1; start = strings.LastIndex(s[:start], ".") {
var k string
if start == end-1 {
k = "."
} else {
k = s[start+1 : end]
}
end = start
if v := curr.Get(k); v != nil {
if v.IsFinal() {
return
}
curr = v
} else {
next := &domain{children: map[string]Domain{}}
curr.Add(k, next)
curr = next
}
}
if s != "." {
next := &domain{children: map[string]Domain{}, end: true}
curr.Add(s[:end], next)
} else {
curr.Finish()
}
}

// Get returns child domain by name
func (l *domain) Get(s string) Domain {
return l.children[s]
}

// NewDomain creates new domain instance
func NewDomain() Domain {
return &domain{children: map[string]Domain{}}
}
157 changes: 157 additions & 0 deletions domain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package fanout

import (
"crypto/rand"
"math/big"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestDomainBasic(t *testing.T) {
samples := []struct {
child string
parent string
expected bool
}{
{".", ".", true},
{"example.org.", ".", true},
{"example.org.", "example.org.", true},
{"example.org", "example.org", true},
{"example.org.", "org.", true},
{"org.", "example.org.", false},
}

for i, s := range samples {
l := NewDomain()
l.AddString(s.parent)
require.Equal(t, s.expected, l.Contains(s.child), i)
}
}

func TestDomainGet(t *testing.T) {
d := NewDomain()
d.AddString("google.com.")
d.AddString("example.com.")
require.True(t, d.Get(".").Get("com").Get("google").IsFinal())
}

func TestDomain_ContainsShouldWorkFast(t *testing.T) {
var samples []string
d := NewDomain()
for i := 0; i < 100; i++ {
for j := 0; j < 100; j++ {
samples = append(samples, genSample(i+1))
d.AddString(samples[len(samples)-1])
}
}
start := time.Now()
for i := 0; i < 10000; i++ {
require.True(t, d.Contains(samples[i]))
}
require.True(t, time.Since(start) < time.Second/5)
}

func TestDomainFewEntries(t *testing.T) {
d := NewDomain()
d.AddString("google.com.")
d.AddString("example.com.")
require.True(t, d.Contains("google.com."))
require.True(t, d.Contains("example.com."))
require.False(t, d.Contains("com."))
}

func TestDomain_DoNotStoreExtraEntries(t *testing.T) {
d := NewDomain()
d.AddString("example.com.")
d.AddString("advanced.example.com.")
require.Nil(t, d.Get(".").Get("com").Get("example").Get("advanced"))
}

func BenchmarkDomain_ContainsMatch(b *testing.B) {
d := NewDomain()
var samples []string
for i := 0; i < 100; i++ {
for j := 0; j < 100; j++ {
samples = append(samples, genSample(i+1))
d.AddString(samples[len(samples)-1])
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 10000; j++ {
d.Contains(samples[j])
}
}
}

func BenchmarkDomain_AddString(b *testing.B) {
d := NewDomain()
var samples []string
for i := 0; i < 100; i++ {
for j := 0; j < 100; j++ {
samples = append(samples, genSample(i+1))
}
}
for i := 0; i < b.N; i++ {
for j := 0; j < len(samples); j++ {
d.AddString(samples[j])
}
}
}

func BenchmarkDomain_ContainsAny(b *testing.B) {
d := NewDomain()
var samples []string
for i := 0; i < 100; i++ {
for j := 0; j < 100; j++ {
d.AddString(genSample(i + 1))
samples = append(samples, genSample(i+1))
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < len(samples); j++ {
d.Contains(samples[j])
}
}
}

func genSample(n int) string {
randInt := func() int {
r, err := rand.Int(rand.Reader, big.NewInt(100))
if err != nil {
panic(err.Error())
}
return int(r.Int64())
}

var sb strings.Builder
for segment := 0; segment < n; segment++ {
l := randInt()%9 + 1
for i := 0; i < l; i++ {
v := (randInt() % 26) + 97
_, _ = sb.WriteRune(rune(v))
}
_, _ = sb.WriteRune('.')
}
return sb.String()
}
43 changes: 16 additions & 27 deletions fanout.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,26 @@ var log = clog.NewWithPlugin("fanout")

// Fanout represents a plugin instance that can do async requests to list of DNS servers.
type Fanout struct {
clients []Client
tlsConfig *tls.Config
ignored []string
tlsServerName string
timeout time.Duration
net string
from string
attempts int
workerCount int
Next plugin.Handler
clients []Client
tlsConfig *tls.Config
excludeDomains Domain
tlsServerName string
timeout time.Duration
net string
from string
attempts int
workerCount int
Next plugin.Handler
}

// New returns reference to new Fanout plugin instance with default configs.
func New() *Fanout {
return &Fanout{
tlsConfig: new(tls.Config),
net: "udp",
attempts: 3,
timeout: defaultTimeout,
tlsConfig: new(tls.Config),
net: "udp",
attempts: 3,
timeout: defaultTimeout,
excludeDomains: NewDomain(),
}
}

Expand Down Expand Up @@ -135,24 +136,12 @@ func (f *Fanout) getFanoutResult(ctx context.Context, responseCh <-chan *respons
}

func (f *Fanout) match(state *request.Request) bool {
if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) {
if !plugin.Name(f.from).Matches(state.Name()) || f.excludeDomains.Contains(state.Name()) {
return false
}
return true
}

func (f *Fanout) isAllowedDomain(name string) bool {
if dns.Name(name) == dns.Name(f.from) {
return true
}
for _, ignore := range f.ignored {
if plugin.Name(ignore).Matches(name) {
return false
}
}
return true
}

func (f *Fanout) processClient(ctx context.Context, c Client, r *request.Request) *response {
start := time.Now()
for j := 0; j < f.attempts || f.attempts == 0; <-time.After(attemptDelay) {
Expand Down
Loading

0 comments on commit 89e1e9d

Please sign in to comment.