Skip to content

Commit

Permalink
Merge branch 'master' of github.com:redsift/spf into feature/bed-336
Browse files Browse the repository at this point in the history
 Conflicts:
	listener.go
	parser.go
	printer/printer.go
	printer/printer_test.go
  • Loading branch information
csucu committed Sep 5, 2023
2 parents a4808e3 + cea6f3a commit d01b6eb
Show file tree
Hide file tree
Showing 15 changed files with 578 additions and 253 deletions.
2 changes: 1 addition & 1 deletion lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func lex(input string) []*token {
return tokens
}

// scan scans input and returns a Token structure
// scan scans input and returns a token structure
func (l *lexer) scan() *token {
for {
r, eof := l.next()
Expand Down
8 changes: 5 additions & 3 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ package spf

import (
"net"
"time"
)

type Listener interface {
CheckHost(ip net.IP, domain, sender string)
CheckHostResult(r Result, explanation string, ttl time.Duration, err error)
CheckHostResult(r Result, explanation string, extras *ResponseExtras, err error)
SPFRecord(s string)
Directive(unused bool, qualifier, mechanism, value, effectiveValue string)
NonMatch(qualifier, mechanism, value string, result Result, err error)
Match(qualifier, mechanism, value string, result Result, explanation string, ttl time.Duration, err error)
Match(qualifier, mechanism, value string, result Result, explanation string, extras *ResponseExtras, err error)
FirstMatch(r Result, err error)
MatchingIP(qualifier, mechanism, value string, fqdn string, ipn net.IPNet, host string, ip net.IP)
// VoidLookup Should only be called after a Directive or CheckHost call, to ensure count is updated to correct
// directive and state is correct
VoidLookup(qualifier, mechanism, value string, fqdn string)
}
2 changes: 1 addition & 1 deletion macro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func TestMacro_Domains(t *testing.T) {
continue
}
t.Run(fmt.Sprintf("%d_%s", no, test.query), func(t *testing.T) {
got, exp, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4)),
got, exp, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4, 2)),
HeloDomain(test.helo),
EvaluatedOn(time.Unix(1, 0)),
ReceivingFQDN(test.receivingFQDN)).
Expand Down
109 changes: 70 additions & 39 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func newParser(opts ...Option) *parser {
func newParserWithVisited(visited *stringsStack, fireFirstMatchOnce *sync.Once, opts ...Option) *parser {
p := &parser{
// mechanisms: make([]*token, 0, 10),
resolver: NewLimitedResolver(&DNSResolver{}, 10, 10),
resolver: NewLimitedResolver(&DNSResolver{}, 10, 10, 2),
options: opts,
visited: visited,
receivingFQDN: "unknown",
Expand All @@ -167,10 +167,10 @@ func newParserWithVisited(visited *stringsStack, fireFirstMatchOnce *sync.Once,
// and error as the reason for the encountered problem.
func (p *parser) checkHost(ip net.IP, domain, sender string) (r Result, expl string, spf string, err error) {
var u unused
var ttl time.Duration
var extras *ResponseExtras
p.fireCheckHost(ip, domain, sender)
defer func() {
p.fireCheckHostResult(r, expl, ttl, err)
p.fireCheckHostResult(r, expl, extras, err)
for _, t := range u.mechanisms {
p.fireUnusedDirective(t)
}
Expand All @@ -192,7 +192,11 @@ func (p *parser) checkHost(ip net.IP, domain, sender string) (r Result, expl str
}

var txts []string
txts, ttl, err = p.resolver.LookupTXTStrict(NormalizeFQDN(domain))
txts, extras, err = p.resolver.LookupTXTStrict(NormalizeFQDN(domain))
if extras.Void() {
p.fireVoidLookup(nil, domain)
}

switch err {
case nil:
// continue
Expand Down Expand Up @@ -249,7 +253,7 @@ func (p *parser) check() (Result, string, unused, error) {
matches bool
token *token
i int
ttl time.Duration
extras *ResponseExtras
)

mechanisms, redirect, explanation, err := sortTokens(tokens)
Expand All @@ -266,19 +270,19 @@ func (p *parser) check() (Result, string, unused, error) {
all = true
matches, result, err = p.parseAll(token)
case tA:
matches, result, ttl, err = p.parseA(token)
matches, result, extras, err = p.parseA(token)
case tIP4:
matches, result, err = p.parseIP4(token)
case tIP6:
matches, result, err = p.parseIP6(token)
case tMX:
matches, result, ttl, err = p.parseMX(token)
matches, result, extras, err = p.parseMX(token)
case tInclude:
matches, result, err = p.parseInclude(token)
case tExists:
matches, result, ttl, err = p.parseExists(token)
matches, result, extras, err = p.parseExists(token)
case tPTR:
matches, result, err = p.parsePtr(token)
matches, result, extras, err = p.parsePtr(token)
default:
p.fireDirective(token, "")
}
Expand All @@ -288,7 +292,7 @@ func (p *parser) check() (Result, string, unused, error) {
if result == Fail && explanation != nil {
s, err = p.handleExplanation(explanation)
}
p.fireMatch(token, result, s, ttl, err)
p.fireMatch(token, result, s, extras, err)
return result, s, unused{mechanisms[i+1:], redirect}, err
}

Expand Down Expand Up @@ -326,11 +330,11 @@ func (p *parser) fireCheckHost(ip net.IP, domain, sender string) {
p.listener.CheckHost(ip, domain, sender)
}

func (p *parser) fireCheckHostResult(r Result, explanation string, ttl time.Duration, e error) {
func (p *parser) fireCheckHostResult(r Result, explanation string, extras *ResponseExtras, e error) {
if p.listener == nil {
return
}
p.listener.CheckHostResult(r, explanation, ttl, e)
p.listener.CheckHostResult(r, explanation, extras, e)
}

func (p *parser) fireSPFRecord(s string) {
Expand Down Expand Up @@ -368,11 +372,24 @@ func (p *parser) fireNonMatch(t *token, r Result, e error) {
p.listener.NonMatch(t.qualifier.String(), t.mechanism.String(), t.value, r, e)
}

func (p *parser) fireMatch(t *token, r Result, explanation string, ttl time.Duration, e error) {
func (p *parser) fireMatch(t *token, r Result, explanation string, extras *ResponseExtras, e error) {
if p.listener == nil {
return
}
p.listener.Match(t.qualifier.String(), t.mechanism.String(), t.value, r, explanation, ttl, e)
p.listener.Match(t.qualifier.String(), t.mechanism.String(), t.value, r, explanation, extras, e)
}

func (p *parser) fireVoidLookup(t *token, fqdn string) {
if p.listener == nil {
return
}

if t == nil {
p.listener.VoidLookup("", "", "", fqdn)
return
}

p.listener.VoidLookup(t.qualifier.String(), t.mechanism.String(), t.value, fqdn)
}

func (p *parser) fireFirstMatch(r Result, e error) {
Expand Down Expand Up @@ -488,7 +505,7 @@ func (p *parser) parseIP6(t *token) (bool, Result, error) {
return ip.Equal(p.ip), result, nil
}

func (p *parser) parseA(t *token) (bool, Result, time.Duration, error) {
func (p *parser) parseA(t *token) (bool, Result, *ResponseExtras, error) {
fqdn, ip4Mask, ip6Mask, err := splitDomainDualCIDR(domainSpec(t.value, p.domain))
if err == nil {
fqdn, _, err = parseMacro(p, fqdn, false)
Expand All @@ -502,12 +519,12 @@ func (p *parser) parseA(t *token) (bool, Result, time.Duration, error) {
fqdn = NormalizeFQDN(fqdn)
p.fireDirective(t, fqdn)
if err != nil {
return true, Permerror, 0, SpfError{Syntax, t, err}
return true, Permerror, nil, SpfError{Syntax, t, err}
}

result, _ := matchingResult(t.qualifier)

found, ttl, err := p.resolver.MatchIP(fqdn, func(ip net.IP, host string) (bool, error) {
found, extras, err := p.resolver.MatchIP(fqdn, func(ip net.IP, host string) (bool, error) {
n := net.IPNet{
IP: ip,
}
Expand All @@ -520,13 +537,16 @@ func (p *parser) parseA(t *token) (bool, Result, time.Duration, error) {
p.fireMatchingIP(t, fqdn, n, host, p.ip)
return n.Contains(p.ip), nil
})
if extras.Void() {
p.fireVoidLookup(t, fqdn)
}
if err != nil {
return found, result, ttl, SpfError{kind: DNS, err: err}
return found, result, nil, SpfError{kind: DNS, err: err}
}
return found, result, ttl, err
return found, result, extras, err
}

func (p *parser) parseMX(t *token) (bool, Result, time.Duration, error) {
func (p *parser) parseMX(t *token) (bool, Result, *ResponseExtras, error) {
fqdn, ip4Mask, ip6Mask, err := splitDomainDualCIDR(domainSpec(t.value, p.domain))
if err == nil {
fqdn, _, err = parseMacro(p, fqdn, false)
Expand All @@ -540,11 +560,11 @@ func (p *parser) parseMX(t *token) (bool, Result, time.Duration, error) {
fqdn = NormalizeFQDN(fqdn)
p.fireDirective(t, fqdn)
if err != nil {
return true, Permerror, 0, SpfError{Syntax, t, err}
return true, Permerror, nil, SpfError{Syntax, t, err}
}

result, _ := matchingResult(t.qualifier)
found, ttl, err := p.resolver.MatchMX(fqdn, func(ip net.IP, host string) (bool, error) {
found, extras, err := p.resolver.MatchMX(fqdn, func(ip net.IP, host string) (bool, error) {
n := net.IPNet{
IP: ip,
}
Expand All @@ -557,10 +577,13 @@ func (p *parser) parseMX(t *token) (bool, Result, time.Duration, error) {
p.fireMatchingIP(t, fqdn, n, host, p.ip)
return n.Contains(p.ip), nil
})
if extras.Void() {
p.fireVoidLookup(t, fqdn)
}
if err != nil {
return true, Permerror, 0, SpfError{DNS, t, err}
return true, Permerror, nil, SpfError{DNS, t, err}
}
return found, result, ttl, err
return found, result, extras, err
}

func (p *parser) parseInclude(t *token) (bool, Result, error) {
Expand Down Expand Up @@ -626,7 +649,7 @@ func (p *parser) parseInclude(t *token) (bool, Result, error) {
}
}

func (p *parser) parseExists(t *token) (bool, Result, time.Duration, error) {
func (p *parser) parseExists(t *token) (bool, Result, *ResponseExtras, error) {
resolvedDomain, missingMacros, err := parseMacroToken(p, t)
if err == nil {
resolvedDomain, err = truncateFQDN(resolvedDomain)
Expand All @@ -641,27 +664,31 @@ func (p *parser) parseExists(t *token) (bool, Result, time.Duration, error) {
resolvedDomain = NormalizeFQDN(resolvedDomain)
p.fireDirective(t, resolvedDomain)
if err != nil {
return true, Permerror, 0, SpfError{Syntax, t, err}
return true, Permerror, nil, SpfError{Syntax, t, err}
}
if resolvedDomain == "" {
return true, Permerror, 0, SpfError{Syntax, t, ErrEmptyDomain}
return true, Permerror, nil, SpfError{Syntax, t, ErrEmptyDomain}
}

result, _ := matchingResult(t.qualifier)

found, ttl, err := p.resolver.Exists(resolvedDomain)
found, extras, err := p.resolver.Exists(resolvedDomain)
if extras.Void() {
p.fireVoidLookup(t, resolvedDomain)
}

switch err {
case nil:
return found, result, ttl, nil
return found, result, extras, nil
case ErrDNSPermerror:
return false, result, 0, nil
return false, result, nil, nil
default:
return false, Temperror, 0, SpfError{kind: DNS, err: err} // was true 8-|
return false, Temperror, nil, SpfError{kind: DNS, err: err} // was true 8-|
}
}

// https://www.rfc-editor.org/rfc/rfc7208#section-5.5
func (p *parser) parsePtr(t *token) (bool, Result, error) {
func (p *parser) parsePtr(t *token) (bool, Result, *ResponseExtras, error) {
fqdn := domainSpec(t.value, p.domain)
fqdn, _, err := parseMacro(p, fqdn, false)
if err == nil {
Expand All @@ -673,19 +700,23 @@ func (p *parser) parsePtr(t *token) (bool, Result, error) {
fqdn = NormalizeFQDN(fqdn)
p.fireDirective(t, fqdn)
if err != nil {
return true, Permerror, SpfError{Syntax, t, err}
return true, Permerror, nil, SpfError{Syntax, t, err}
}

ptrs, extras, err := p.resolver.LookupPTR(p.ip.String())
if extras.Void() {
p.fireVoidLookup(t, fqdn)
}

ptrs, _, err := p.resolver.LookupPTR(p.ip.String())
switch err {
case nil:
// continue
case ErrDNSLimitExceeded:
return false, Permerror, SpfError{kind: DNS, err: err}
return false, Permerror, extras, SpfError{kind: DNS, err: err}
case ErrDNSPermerror:
return false, None, SpfError{kind: DNS, err: err}
return false, None, extras, SpfError{kind: DNS, err: err}
default:
return false, Temperror, SpfError{kind: DNS, err: err}
return false, Temperror, extras, SpfError{kind: DNS, err: err}
}

result, _ := matchingResult(t.qualifier)
Expand All @@ -705,11 +736,11 @@ func (p *parser) parsePtr(t *token) (bool, Result, error) {
}

if found {
return true, result, nil
return true, result, nil, nil
}
}

return false, Fail, nil
return false, Fail, nil, nil
}

func (p *parser) handleRedirect(t *token) (Result, error) {
Expand Down
6 changes: 3 additions & 3 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ func TestParseMXNegativeTests(t *testing.T) {
testcases := []TokenTestCase{
{&token{tMX, qPlus, "matching.com"}, Pass, false, false},
{&token{tMX, qPlus, ""}, Pass, false, false},
// TokenTestCase{&Token{tMX, qPlus, "google.com"}, Pass, false},
// TokenTestCase{&token{tMX, qPlus, "google.com"}, Pass, false},
{&token{tMX, qPlus, "idontexist"}, Pass, false, false},
{&token{tMX, qMinus, "matching.com"}, Fail, false, false},

Expand Down Expand Up @@ -1035,7 +1035,7 @@ func TestParse(t *testing.T) {
}
done := make(chan R)
go func() {
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 5, 4))).with(testcase.Query, "matching.com", "matching.com", testcase.IP).check()
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 5, 4, 2))).with(testcase.Query, "matching.com", "matching.com", testcase.IP).check()
done <- R{result, err}
}()
select {
Expand Down Expand Up @@ -1105,7 +1105,7 @@ func TestCheckHost_RecursionLoop(t *testing.T) {
}
done := make(chan R)
go func() {
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4))).with(test.query, "matching.com", "matching.com", test.ip).check()
result, _, _, err := newParser(WithResolver(NewLimitedResolver(testResolver, 4, 4, 2))).with(test.query, "matching.com", "matching.com", test.ip).check()
done <- R{result, err}
}()
select {
Expand Down
Loading

0 comments on commit d01b6eb

Please sign in to comment.