Skip to content

Commit

Permalink
all: improve api
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Dec 8, 2020
1 parent dfe7e0e commit d33afd8
Show file tree
Hide file tree
Showing 12 changed files with 348 additions and 320 deletions.
49 changes: 26 additions & 23 deletions dnsengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,23 @@ type DNSResult struct {
HostRulesV4 []*rules.HostRule // host rules for IPv4 or nil
HostRulesV6 []*rules.HostRule // host rules for IPv6 or nil

// DNSRewriteNetworkRules are the DNS rewrite rules set with $dnsrewrite
// modifiers.
DNSRewriteNetworkRules []*rules.NetworkRule
// networkRules are all matched network rules.
networkRules []*rules.NetworkRule
}

// DNSRewrites returns all $dnsrewrite network rules.
func (res *DNSResult) DNSRewrites() (rules []*rules.NetworkRule) {
if res == nil {
return nil
}

for _, nr := range res.networkRules {
if nr.DNSRewrite != nil {
rules = append(rules, nr)
}
}

return rules
}

// DNSRequest represents a DNS query with associated metadata.
Expand Down Expand Up @@ -109,7 +123,7 @@ func (d *DNSEngine) Match(hostname string) (DNSResult, bool) {
// For instance:
// 192.168.0.1 example.local
// 2000::1 example.local
func (d *DNSEngine) MatchRequest(dReq DNSRequest) (DNSResult, bool) {
func (d *DNSEngine) MatchRequest(dReq DNSRequest) (res DNSResult, matched bool) {
if dReq.Hostname == "" {
return DNSResult{}, false
}
Expand All @@ -120,45 +134,34 @@ func (d *DNSEngine) MatchRequest(dReq DNSRequest) (DNSResult, bool) {
r.ClientName = dReq.ClientName
r.DNSType = dReq.DNSType

networkRules := d.networkEngine.MatchAll(r)
res.networkRules = d.networkEngine.MatchAll(r)

var dnsRewriteRules []*rules.NetworkRule
for _, nr := range networkRules {
if nr.DNSRewrite != nil {
dnsRewriteRules = append(dnsRewriteRules, nr)
}
}

if len(dnsRewriteRules) > 0 {
// DNS rewrite rules have a higher priority.
return DNSResult{
DNSRewriteNetworkRules: dnsRewriteRules,
}, true
}

result := rules.NewMatchingResult(networkRules, nil)
result := rules.NewMatchingResult(res.networkRules, nil)
resultRule := result.GetBasicResult()
if resultRule != nil {
// Network rules always have higher priority
return DNSResult{NetworkRule: resultRule}, true
// Network rules always have higher priority.
res.NetworkRule = resultRule
return res, true
}

rr, ok := d.matchLookupTable(dReq.Hostname)
if !ok {
return DNSResult{}, false
}
res := DNSResult{}

for _, rule := range rr {
hostRule, ok := rule.(*rules.HostRule)
if !ok {
continue
}

if hostRule.IP.To4() != nil {
res.HostRulesV4 = append(res.HostRulesV4, hostRule)
} else {
res.HostRulesV6 = append(res.HostRulesV6, hostRule)
}
}

return res, true
}

Expand Down
2 changes: 1 addition & 1 deletion dnsengine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestBenchDNSEngine(t *testing.T) {
if err != nil {
t.Fatalf("cannot create rule storage: %s", err)
}
defer ruleStorage.Close()
defer func() { assert.Nil(t, ruleStorage.Close()) }()

testRequests := loadRequests(t)
assert.True(t, len(testRequests) > 0)
Expand Down
112 changes: 60 additions & 52 deletions dnsrewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
|new_cname^$dnsrewrite=othercname
|new_txt^$dnsrewrite=NOERROR;TXT;new_txtcontent
|priority^$client=127.0.0.1
|priority^$dnsrewrite=127.0.0.1
|https_type^$dnstype=HTTPS,dnsrewrite=REFUSED
|disable_one^$dnsrewrite=127.0.0.1
Expand Down Expand Up @@ -59,8 +56,10 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("short_v4", func(t *testing.T) {
res, ok := dnsEngine.Match("short_v4")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p1, nr.DNSRewrite.Value)
Expand All @@ -70,13 +69,15 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("short_v4_multiple", func(t *testing.T) {
res, ok := dnsEngine.Match("short_v4_multiple")
assert.True(t, ok)
if assert.Equal(t, 2, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 2, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p1, nr.DNSRewrite.Value)

nr = res.DNSRewriteNetworkRules[1]
nr = dnsr[1]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p2, nr.DNSRewrite.Value)
Expand All @@ -86,8 +87,10 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("normal_v4", func(t *testing.T) {
res, ok := dnsEngine.Match("normal_v4")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p1, nr.DNSRewrite.Value)
Expand All @@ -97,13 +100,15 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("normal_v4_multiple", func(t *testing.T) {
res, ok := dnsEngine.Match("normal_v4_multiple")
assert.True(t, ok)
if assert.Equal(t, 2, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 2, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p1, nr.DNSRewrite.Value)

nr = res.DNSRewriteNetworkRules[1]
nr = dnsr[1]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p2, nr.DNSRewrite.Value)
Expand All @@ -113,8 +118,10 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("short_v6", func(t *testing.T) {
res, ok := dnsEngine.Match("short_v6")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p1, nr.DNSRewrite.Value)
Expand All @@ -124,13 +131,15 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("short_v6_multiple", func(t *testing.T) {
res, ok := dnsEngine.Match("short_v6_multiple")
assert.True(t, ok)
if assert.Equal(t, 2, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 2, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p1, nr.DNSRewrite.Value)

nr = res.DNSRewriteNetworkRules[1]
nr = dnsr[1]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p2, nr.DNSRewrite.Value)
Expand All @@ -140,8 +149,10 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("normal_v6", func(t *testing.T) {
res, ok := dnsEngine.Match("normal_v6")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p1, nr.DNSRewrite.Value)
Expand All @@ -151,13 +162,15 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("normal_v6_multiple", func(t *testing.T) {
res, ok := dnsEngine.Match("normal_v6_multiple")
assert.True(t, ok)
if assert.Equal(t, 2, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 2, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p1, nr.DNSRewrite.Value)

nr = res.DNSRewriteNetworkRules[1]
nr = dnsr[1]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeAAAA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv6p2, nr.DNSRewrite.Value)
Expand All @@ -167,26 +180,32 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
t.Run("refused_host", func(t *testing.T) {
res, ok := dnsEngine.Match("refused_host")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeRefused, nr.DNSRewrite.RCode)
}
})

t.Run("new_cname", func(t *testing.T) {
res, ok := dnsEngine.Match("new_cname")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, "othercname", nr.DNSRewrite.NewCNAME)
}
})

t.Run("new_txt", func(t *testing.T) {
res, ok := dnsEngine.Match("new_txt")
assert.True(t, ok)
if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]

dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeTXT, nr.DNSRewrite.RRType)
assert.Equal(t, "new_txtcontent", nr.DNSRewrite.Value)
Expand All @@ -202,8 +221,9 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
res, ok := dnsEngine.MatchRequest(r)
assert.True(t, ok)

if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]
dnsr := res.DNSRewrites()
if assert.Equal(t, 1, len(dnsr)) {
nr := dnsr[0]
assert.Equal(t, dns.RcodeRefused, nr.DNSRewrite.RCode)
}

Expand All @@ -216,28 +236,14 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
assert.False(t, ok)
})

t.Run("priority", func(t *testing.T) {
res, ok := dnsEngine.Match("priority")
assert.True(t, ok)
assert.Nil(t, res.NetworkRule)
assert.Nil(t, res.HostRulesV4)
assert.Nil(t, res.HostRulesV6)

if assert.Equal(t, 1, len(res.DNSRewriteNetworkRules)) {
nr := res.DNSRewriteNetworkRules[0]
assert.Equal(t, dns.RcodeSuccess, nr.DNSRewrite.RCode)
assert.Equal(t, dns.TypeA, nr.DNSRewrite.RRType)
assert.Equal(t, ipv4p1, nr.DNSRewrite.Value)
}
})

t.Run("disable_one", func(t *testing.T) {
res, ok := dnsEngine.Match("disable_one")
assert.True(t, ok)

var allowListCase *rules.NetworkRule
if assert.Equal(t, 3, len(res.DNSRewriteNetworkRules)) {
for _, r := range res.DNSRewriteNetworkRules {
dnsr := res.DNSRewrites()
if assert.Equal(t, 3, len(dnsr)) {
for _, r := range dnsr {
if r.Whitelist {
allowListCase = r
}
Expand All @@ -257,8 +263,9 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
assert.True(t, ok)

var allowListCase *rules.NetworkRule
if assert.Equal(t, 3, len(res.DNSRewriteNetworkRules)) {
for _, r := range res.DNSRewriteNetworkRules {
dnsr := res.DNSRewrites()
if assert.Equal(t, 3, len(dnsr)) {
for _, r := range dnsr {
if r.Whitelist {
allowListCase = r
}
Expand All @@ -275,8 +282,9 @@ func TestDNSEngine_MatchRequest_dnsRewrite(t *testing.T) {
assert.True(t, ok)

var allowListCase *rules.NetworkRule
if assert.Equal(t, 3, len(res.DNSRewriteNetworkRules)) {
for _, r := range res.DNSRewriteNetworkRules {
dnsr := res.DNSRewrites()
if assert.Equal(t, 3, len(dnsr)) {
for _, r := range dnsr {
if r.Whitelist {
allowListCase = r
}
Expand Down
2 changes: 1 addition & 1 deletion rules/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func newClients(clientStrs ...string) *clients {
}

// containsAny - checks if "clients" contains host or ipStr
func (c *clients) containsAny(host string, ipStr string) bool {
func (c *clients) containsAny(host, ipStr string) bool {
if c == nil {
return false
}
Expand Down
Loading

0 comments on commit d33afd8

Please sign in to comment.