Skip to content

Commit

Permalink
added void limit to limited resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
csucu committed Aug 18, 2023
1 parent fbee7fa commit 1550a63
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 46 deletions.
2 changes: 1 addition & 1 deletion macro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func TestMacro_Domains(t *testing.T) {
continue
}
t.Run(fmt.Sprintf("%d_%s", no, test.query), func(t *testing.T) {
got, exp, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4)),
got, exp, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4, 2)),
HeloDomain(test.helo),
EvaluatedOn(time.Unix(1, 0)),
ReceivingFQDN(test.receivingFQDN)).
Expand Down
2 changes: 1 addition & 1 deletion parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func newParser(opts ...Option) *parser {
func newParserWithVisited(visited *stringsStack, opts ...Option) *parser {
p := &parser{
// mechanisms: make([]*token, 0, 10),
resolver: NewLimitedResolver(&DNSResolver{}, 10, 10),
resolver: NewLimitedResolver(&DNSResolver{}, 10, 10, 2),
options: opts,
visited: visited,
receivingFQDN: "unknown",
Expand Down
4 changes: 2 additions & 2 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ func TestParse(t *testing.T) {
}
done := make(chan R)
go func() {
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 5, 4))).with(testcase.Query, "matching.com", "matching.com", testcase.IP).check()
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 5, 4, 2))).with(testcase.Query, "matching.com", "matching.com", testcase.IP).check()
done <- R{result, err}
}()
select {
Expand Down Expand Up @@ -1113,7 +1113,7 @@ func TestCheckHost_RecursionLoop(t *testing.T) {
}
done := make(chan R)
go func() {
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4))).with(test.query, "matching.com", "matching.com", test.ip).check()
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4, 2))).with(test.query, "matching.com", "matching.com", test.ip).check()
done <- R{result, err}
}()
select {
Expand Down
68 changes: 56 additions & 12 deletions resolver_limited.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,34 @@ import (
// LimitedResolver wraps a Resolver and limits number of lookups possible to do
// with it. All overlimited calls return ErrDNSLimitExceeded.
type LimitedResolver struct {
lookupLimit int32
mxQueriesLimit uint16
resolver Resolver
lookupLimit int32
mxQueriesLimit uint16
voidLookupLimit int32
resolver Resolver
}

// NewLimitedResolver returns a resolver which will pass up to lookupLimit calls to r.
// In addition to that limit, the evaluation of each "MX" record will be limited
// to mxQueryLimit.
// All calls over the limit will return ErrDNSLimitExceeded.
// Make sure lookupLimit includes the initial SPF lookup
func NewLimitedResolver(r Resolver, lookupLimit, mxQueriesLimit uint16) Resolver {
func NewLimitedResolver(r Resolver, lookupLimit, mxQueriesLimit, voidLookupLimit uint16) Resolver {
return &LimitedResolver{
lookupLimit: int32(lookupLimit), // sure that l is positive or zero
mxQueriesLimit: mxQueriesLimit,
resolver: r,
lookupLimit: int32(lookupLimit), // sure that l is positive or zero
mxQueriesLimit: mxQueriesLimit,
voidLookupLimit: int32(voidLookupLimit),
resolver: r,
}
}

func (r *LimitedResolver) canLookup() bool {
return atomic.AddInt32(&r.lookupLimit, -1) > 0
}

func (r *LimitedResolver) canPerformVoidLookup() bool {
return atomic.AddInt32(&r.voidLookupLimit, -1) > 0
}

// LookupTXT returns the DNS TXT records for the given domain name
// and the minimum TTL. Used for "exp" modifier and do not cause DNS query.
func (r *LimitedResolver) LookupTXT(name string) ([]string, *ResponseExtras, error) {
Expand All @@ -45,7 +51,15 @@ func (r *LimitedResolver) LookupTXTStrict(name string) ([]string, *ResponseExtra
if !r.canLookup() {
return nil, nil, ErrDNSLimitExceeded
}
return r.resolver.LookupTXTStrict(name)

txts, extras, err := r.resolver.LookupTXTStrict(name)
if extras != nil && extras.Void {
if !r.canPerformVoidLookup() {
return nil, nil, ErrDNSVoidLookupLimitExceeded
}
}

return txts, extras, err
}

// Exists is used for a DNS A RR lookup (even when the
Expand All @@ -57,7 +71,15 @@ func (r *LimitedResolver) Exists(name string) (bool, *ResponseExtras, error) {
if !r.canLookup() {
return false, nil, ErrDNSLimitExceeded
}
return r.resolver.Exists(name)

found, extras, err := r.resolver.Exists(name)
if extras != nil && extras.Void {
if !r.canPerformVoidLookup() {
return false, nil, ErrDNSVoidLookupLimitExceeded
}
}

return found, extras, err
}

// MatchIP provides an address lookup, which should be done on the name
Expand All @@ -70,7 +92,15 @@ func (r *LimitedResolver) MatchIP(name string, matcher IPMatcherFunc) (bool, *Re
if !r.canLookup() {
return false, nil, ErrDNSLimitExceeded
}
return r.resolver.MatchIP(name, matcher)

found, extras, err := r.resolver.MatchIP(name, matcher)
if extras != nil && extras.Void {
if !r.canPerformVoidLookup() {
return false, nil, ErrDNSVoidLookupLimitExceeded
}
}

return found, extras, err
}

// MatchMX is similar to MatchIP but first performs an MX lookup on the
Expand All @@ -91,12 +121,19 @@ func (r *LimitedResolver) MatchMX(name string, matcher IPMatcherFunc) (bool, *Re
}

limit := int32(r.mxQueriesLimit)
return r.resolver.MatchMX(name, func(ip net.IP, name string) (bool, error) {
found, extras, err := r.resolver.MatchMX(name, func(ip net.IP, name string) (bool, error) {
if atomic.AddInt32(&limit, -1) < 1 {
return false, ErrDNSLimitExceeded
}
return matcher(ip, name)
})
if extras != nil && extras.Void {
if !r.canPerformVoidLookup() {
return false, nil, ErrDNSVoidLookupLimitExceeded
}
}

return found, extras, err
}

// LookupPTR returns the DNS PTR records for the given domain name
Expand All @@ -105,5 +142,12 @@ func (r *LimitedResolver) LookupPTR(name string) ([]string, *ResponseExtras, err
if !r.canLookup() {
return nil, nil, ErrDNSLimitExceeded
}
return r.resolver.LookupPTR(name)
ptrs, extras, err := r.resolver.LookupPTR(name)
if extras != nil && extras.Void {
if !r.canPerformVoidLookup() {
return nil, nil, ErrDNSVoidLookupLimitExceeded
}
}

return ptrs, extras, err
}
22 changes: 17 additions & 5 deletions resolver_limited_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestLimitedResolver(t *testing.T) {
defer dns.HandleRemove("mxmustfail.")

{
r := NewLimitedResolver(testResolver, 2, 2)
r := NewLimitedResolver(testResolver, 2, 2, 2)
a, _, err := r.LookupTXT("domain.")
if len(a) == 0 || err != nil {
t.Error("failed on 1st LookupTXT")
Expand All @@ -49,7 +49,7 @@ func TestLimitedResolver(t *testing.T) {
}
}
{
r := NewLimitedResolver(testResolver, 2, 2)
r := NewLimitedResolver(testResolver, 2, 2, 2)
b, _, err := r.Exists("domain.")
if !b || err != nil {
t.Error("failed on 1st Exists")
Expand All @@ -65,7 +65,7 @@ func TestLimitedResolver(t *testing.T) {
}
}
{
r := NewLimitedResolver(testResolver, 2, 2)
r := NewLimitedResolver(testResolver, 2, 2, 2)
b, _, err := r.MatchIP("domain.", newMatcher(net.ParseIP("10.0.0.1")))
if !b || err != nil {
t.Error("failed on 1st MatchIP")
Expand All @@ -76,7 +76,7 @@ func TestLimitedResolver(t *testing.T) {
}
}
{
r := NewLimitedResolver(testResolver, 2, 2)
r := NewLimitedResolver(testResolver, 2, 2, 2)
b, _, err := r.MatchMX("domain.", newMatcher(net.ParseIP("10.0.0.1")))
if !b || err != nil {
t.Error("failed on 1st MatchMX")
Expand All @@ -87,10 +87,22 @@ func TestLimitedResolver(t *testing.T) {
}
}
{
r := NewLimitedResolver(testResolver, 2, 2)
r := NewLimitedResolver(testResolver, 2, 2, 2)
b, _, err := r.MatchMX("mxmustfail.", newMatcher(net.ParseIP("10.0.0.10")))
if b || err != ErrDNSLimitExceeded {
t.Errorf("MatchMX got: %v, %v; want false, ErrDNSLimitExceeded", b, err)
}
}
{
dns.HandleFunc("void.test.", Zone(map[uint16][]string{}))
defer dns.HandleRemove("void.test.")

r := NewLimitedResolver(testResolver, 6, 6, 2)
_, _, err := r.LookupTXTStrict("void.test.")
_, _, err = r.LookupTXTStrict("void.test.")
_, _, err = r.LookupTXTStrict("void.test.")
if err != ErrDNSVoidLookupLimitExceeded {
t.Errorf("LookupTXTStrict got: %v; want ErrDNSVoidLookupLimitExceeded", err)
}
}
}
22 changes: 12 additions & 10 deletions resolver_miekg.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,17 @@ func (r *miekgDNSResolver) LookupTXT(name string) ([]string, *ResponseExtras, er
}
}

if len(txts) == 0 {
minTTL = 0
}

extras := &ResponseExtras{
TTL: time.Duration(minTTL) * time.Second,
// We have a void lookup if we have no answers with NoError, or we have NXDomain
Void: (len(res.Answer) == 0 && res.Rcode == dns.RcodeSuccess) || res.Rcode == dns.RcodeNameError,
}

if len(txts) == 0 {
minTTL = 0
}

extras.TTL = time.Duration(minTTL) * time.Second

return txts, extras, nil
}

Expand Down Expand Up @@ -232,16 +233,17 @@ func (r *miekgDNSResolver) LookupTXTStrict(name string) ([]string, *ResponseExtr
}
}

if len(txts) == 0 {
minTTL = 0
}

extras := &ResponseExtras{
TTL: time.Duration(minTTL) * time.Second,
// We have a void lookup if we have no answers with NoError
Void: len(res.Answer) == 0 && res.Rcode == dns.RcodeSuccess,
}

if len(txts) == 0 {
minTTL = 0
}

extras.TTL = time.Duration(minTTL) * time.Second

return txts, extras, nil
}

Expand Down
31 changes: 16 additions & 15 deletions spf.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@ import (

// Errors could be used for root couse analysis
var (
ErrDNSTemperror = errors.New("temporary DNS error")
ErrDNSPermerror = errors.New("permanent DNS error")
ErrDNSLimitExceeded = errors.New("limit exceeded")
ErrSPFNotFound = errors.New("SPF record not found")
ErrInvalidCIDRLength = errors.New("invalid CIDR length")
ErrTooManySPFRecords = errors.New("too many SPF records")
ErrTooManyRedirects = errors.New(`too many "redirect"`)
ErrTooManyExps = errors.New(`too many "exp"`)
ErrSyntaxError = errors.New(`wrong syntax`)
ErrEmptyDomain = errors.New("empty domain")
ErrNotIPv4 = errors.New("address isn't ipv4")
ErrNotIPv6 = errors.New("address isn't ipv6")
ErrLoopDetected = errors.New("infinite recursion detected")
ErrUnreliableResult = errors.New("result is unreliable with IgnoreMatches option enabled")
ErrTooManyErrors = errors.New("too many errors")
ErrDNSTemperror = errors.New("temporary DNS error")
ErrDNSPermerror = errors.New("permanent DNS error")
ErrDNSLimitExceeded = errors.New("limit exceeded")
ErrDNSVoidLookupLimitExceeded = errors.New("void lookup limit exceeded")
ErrSPFNotFound = errors.New("SPF record not found")
ErrInvalidCIDRLength = errors.New("invalid CIDR length")
ErrTooManySPFRecords = errors.New("too many SPF records")
ErrTooManyRedirects = errors.New(`too many "redirect"`)
ErrTooManyExps = errors.New(`too many "exp"`)
ErrSyntaxError = errors.New(`wrong syntax`)
ErrEmptyDomain = errors.New("empty domain")
ErrNotIPv4 = errors.New("address isn't ipv4")
ErrNotIPv6 = errors.New("address isn't ipv6")
ErrLoopDetected = errors.New("infinite recursion detected")
ErrUnreliableResult = errors.New("result is unreliable with IgnoreMatches option enabled")
ErrTooManyErrors = errors.New("too many errors")
)

// DomainError represents a domain check error
Expand Down

0 comments on commit 1550a63

Please sign in to comment.