diff --git a/labels.go b/labels.go index f9faacfeb4..8a6d3487d8 100644 --- a/labels.go +++ b/labels.go @@ -1,5 +1,9 @@ package dns +import ( + "strings" +) + // Holds a bunch of helper functions for dealing with labels. // SplitDomainName splits a name string into it's labels. @@ -186,6 +190,45 @@ func PrevLabel(s string, n int) (i int, start bool) { return 0, n > 1 } +// Compare compares domains according to the canonical ordering specified in RFC4034 +// returns an integer value similar to strcmp +// (0 for equal values, -1 if s1 < s2, 1 if s1 > s2) +func Compare(s1, s2 string) int { + s1b := []byte(s1) + s2b := []byte(s2) + + doDDD(s1b) + doDDD(s2b) + + s1lend := len(s1) + s2lend := len(s2) + + for i := 0; ; i++ { + s1lstart, end1 := PrevLabel(s1, i) + s2lstart, end2 := PrevLabel(s2, i) + + if end1 && end2 { + return 0 + } + + s1l := string(s1b[s1lstart:s1lend]) + s2l := string(s2b[s2lstart:s2lend]) + + if !equal(s1l, s2l) { + return strings.Compare(strings.ToLower(s1l), strings.ToLower(s2l)) + } + + s1lend = s1lstart - 1 + s2lend = s2lstart - 1 + if s1lend == -1 { + s1lend = 0 + } + if s2lend == -1 { + s2lend = 0 + } + } +} + // equal compares a and b while ignoring case. It returns true when equal otherwise false. func equal(a, b string) bool { // might be lifted into API function. @@ -210,3 +253,16 @@ func equal(a, b string) bool { } return true } + +func doDDD(b []byte) { + lb := len(b) + for i := 0; i < lb; i++ { + if i+3 < lb && b[i] == '\\' && isDigit(b[i+1]) && isDigit(b[i+2]) && isDigit(b[i+3]) { + b[i] = dddToByte(b[i+1 : i+4]) + for j := i + 1; j < lb-3; j++ { + b[j] = b[j+3] + } + lb -= 3 + } + } +} diff --git a/labels_test.go b/labels_test.go index 3e672fec82..968c2e58a8 100644 --- a/labels_test.go +++ b/labels_test.go @@ -334,3 +334,40 @@ func BenchmarkPrevLabelMixed(b *testing.B) { PrevLabel(`www\\\.example.com`, 10) } } + +func TestCompare(t *testing.T) { + domains := []string{ // based on an exanple from RFC 4034 + "example.", + "a.example.", + "yljkjljk.a.example.", + "Z.a.example.", + "zABC.a.EXAMPLE.", + "a-.example.", + "z.example.", + "\001.z.example.", + "*.z.example.", + "\200.z.example.", + } + + len_domains := len(domains) + + for i, domain := range domains { + if i != 0 { + prev_domain := domains[i-1] + if !(Compare(prev_domain, domain) == -1 && Compare(domain, prev_domain) == 1) { + t.Fatalf("prev comparison failure between %s and %s", prev_domain, domain) + } + } + + if Compare(domain, domain) != 0 { + t.Fatalf("self comparison failure for %s", domain) + } + + if i != len_domains-1 { + next_domain := domains[i+1] + if !(Compare(domain, next_domain) == -1 && Compare(next_domain, domain) == 1) { + t.Fatalf("next comparison failure between %s and %s, %d and %d", domain, next_domain, Compare(domain, next_domain), Compare(next_domain, domain)) + } + } + } +} diff --git a/nsecx.go b/nsecx.go index f8826817b3..b53ef35e09 100644 --- a/nsecx.go +++ b/nsecx.go @@ -93,3 +93,8 @@ func (rr *NSEC3) Match(name string) bool { } return false } + +// Match returns true if the given name is covered by the NSEC record +func (rr *NSEC) Cover(name string) bool { + return Compare(rr.Hdr.Name, name) <= 0 && Compare(name, rr.NextDomain) == -1 +} diff --git a/nsecx_test.go b/nsecx_test.go index ee92653343..ca42b37b72 100644 --- a/nsecx_test.go +++ b/nsecx_test.go @@ -168,3 +168,23 @@ func BenchmarkHashName(b *testing.B) { }) } } + +func TestNsecCover(t *testing.T) { + nsec := testRR("aaa.ee. 3600 IN NSEC aac.ee. NS RRSIG NSEC").(*NSEC) + + if !nsec.Cover("aaaa.ee.") { + t.Fatal("nsec cover not covering in-range name") + } + + if !nsec.Cover("aaa.ee.") { + t.Fatal("nsec cover not covering start of range") + } + + if nsec.Cover("aac.ee.") { + t.Fatal("nsec cover range end failure") + } + + if nsec.Cover("aad.ee.") { + t.Fatal("nsec cover covering out-of-range name") + } +}