Skip to content

Commit

Permalink
Merge pull request #281 from LevenLabs/master
Browse files Browse the repository at this point in the history
Instead of removing all RRs on Truncated, attempt to unpack
  • Loading branch information
miekg committed Nov 2, 2015
2 parents 497abb0 + 2d2c2eb commit d274557
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 58 deletions.
6 changes: 6 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ func (co *Conn) ReadMsg() (*Msg, error) {

m := new(Msg)
if err := m.Unpack(p); err != nil {
// If ErrTruncated was returned, we still want to allow the user to use
// the message, but naively they can just check err if they don't want
// to use a truncated message
if err == ErrTruncated {
return m, err
}
return nil, err
}
if t := m.IsTsig(); t != nil {
Expand Down
145 changes: 145 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dns

import (
"fmt"
"net"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -236,3 +238,146 @@ func TestClientConn(t *testing.T) {
t.Errorf("unable to unpack message fully: %v", err)
}
}

func TestTruncatedMsg(t *testing.T) {
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSRV)
cnt := 10
for i := 0; i < cnt; i++ {
r := &SRV{
Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0},
Port: uint16(i + 8000),
Target: "target.miek.nl.",
}
m.Answer = append(m.Answer, r)

re := &A{
Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0},
A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(),
}
m.Extra = append(m.Extra, re)
}
buf, err := m.Pack()
if err != nil {
t.Errorf("failed to pack: %v", err)
}

r := new(Msg)
if err = r.Unpack(buf); err != nil {
t.Errorf("unable to unpack message: %v", err)
}
if len(r.Answer) != cnt {
t.Logf("answer count after regular unpack doesn't match: %d", len(r.Answer))
t.Fail()
}
if len(r.Extra) != cnt {
t.Logf("extra count after regular unpack doesn't match: %d", len(r.Extra))
t.Fail()
}

m.Truncated = true
buf, err = m.Pack()
if err != nil {
t.Errorf("failed to pack truncated: %v", err)
}

r = new(Msg)
if err = r.Unpack(buf); err != nil && err != ErrTruncated {
t.Errorf("unable to unpack truncated message: %v", err)
}
if !r.Truncated {
t.Log("truncated message wasn't unpacked as truncated")
t.Fail()
}
if len(r.Answer) != cnt {
t.Logf("answer count after truncated unpack doesn't match: %d", len(r.Answer))
t.Fail()
}
if len(r.Extra) != cnt {
t.Logf("extra count after truncated unpack doesn't match: %d", len(r.Extra))
t.Fail()
}

// Now we want to remove almost all of the extra records
// We're going to loop over the extra to get the count of the size of all
// of them
off := 0
buf1 := make([]byte, m.Len())
for i := 0; i < len(m.Extra); i++ {
off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
if err != nil {
t.Errorf("failed to pack extra: %v", err)
}
}

// Remove all of the extra bytes but 10 bytes from the end of buf
off -= 10
buf1 = buf[:len(buf)-off]

r = new(Msg)
if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
t.Errorf("unable to unpack cutoff message: %v", err)
}
if !r.Truncated {
t.Log("truncated cutoff message wasn't unpacked as truncated")
t.Fail()
}
if len(r.Answer) != cnt {
t.Logf("answer count after cutoff unpack doesn't match: %d", len(r.Answer))
t.Fail()
}
if len(r.Extra) != 0 {
t.Logf("extra count after cutoff unpack is not zero: %d", len(r.Extra))
t.Fail()
}

// Now we want to remove almost all of the answer records too
buf1 = make([]byte, m.Len())
as := 0
for i := 0; i < len(m.Extra); i++ {
off1 := off
off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
as = off - off1
if err != nil {
t.Errorf("failed to pack extra: %v", err)
}
}

// Keep exactly one answer left
// This should still cause Answer to be nil
off -= as
buf1 = buf[:len(buf)-off]

r = new(Msg)
if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
t.Errorf("unable to unpack cutoff message: %v", err)
}
if !r.Truncated {
t.Log("truncated cutoff message wasn't unpacked as truncated")
t.Fail()
}
if len(r.Answer) != 0 {
t.Logf("answer count after second cutoff unpack is not zero: %d", len(r.Answer))
t.Fail()
}

// Now leave only 1 byte of the question
// Since the header is always 12 bytes, we just need to keep 13
buf1 = buf[:13]

r = new(Msg)
err = r.Unpack(buf1)
if err == nil || err == ErrTruncated {
t.Logf("error should not be ErrTruncated from question cutoff unpack: %v", err)
t.Fail()
}

// Finally, if we only have the header, we should still return an error
buf1 = buf[:12]

r = new(Msg)
if err = r.Unpack(buf1); err == nil || err != ErrTruncated {
t.Logf("error not ErrTruncated from header-only unpack: %v", err)
t.Fail()
}
}
109 changes: 51 additions & 58 deletions msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ var (
ErrSoa error = &Error{err: "no SOA"}
// ErrTime indicates a timing error in TSIG authentication.
ErrTime error = &Error{err: "bad time"}
// ErrTruncated indicates that we failed to unpack a truncated message.
// We unpacked as much as we had so Msg can still be used, if desired.
ErrTruncated error = &Error{err: "failed to unpack truncated message"}
)

// Id, by default, returns a 16 bits random number to be used as a
Expand Down Expand Up @@ -1238,8 +1241,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
continue
}
}
if off == lenmsg {
// zero rdata foo, OK for dyn. updates
if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 {
// zero rdata is ok for dyn updates, but only if rdlength is 0
break
}
s, off, err = UnpackDomainName(msg, off)
Expand Down Expand Up @@ -1396,6 +1399,32 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
return rr, off, err
}

// unpackRRslice unpacks msg[off:] into an []RR.
// If we cannot unpack the whole array, then it will return nil
func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
var r RR
// Optimistically make dst be the length that was sent
dst := make([]RR, 0, l)
for i := 0; i < l; i++ {
off1 := off
r, off, err = UnpackRR(msg, off)
if err != nil {
off = len(msg)
break
}
// If offset does not increase anymore, l is a lie
if off1 == off {
l = i
break
}
dst = append(dst, r)
}
if err != nil && off == len(msg) {
dst = nil
}
return dst, off, err
}

// Reverse a map
func reverseInt8(m map[uint8]string) map[string]uint8 {
n := make(map[string]uint8)
Expand Down Expand Up @@ -1594,84 +1623,48 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
dns.CheckingDisabled = (dh.Bits & _CD) != 0
dns.Rcode = int(dh.Bits & 0xF)

// Don't pre-alloc these arrays, the incoming lengths are from the network.
dns.Question = make([]Question, 0, 1)
dns.Answer = make([]RR, 0, 10)
dns.Ns = make([]RR, 0, 10)
dns.Extra = make([]RR, 0, 10)
// Optimistically use the count given to us in the header
dns.Question = make([]Question, 0, int(dh.Qdcount))

var q Question
for i := 0; i < int(dh.Qdcount); i++ {
off1 := off
off, err = UnpackStruct(&q, msg, off)
if err != nil {
// Even if Truncated is set, we only will set ErrTruncated if we
// actually got the questions
return err
}
if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
dh.Qdcount = uint16(i)
break
}

dns.Question = append(dns.Question, q)

}
// If we see a TC bit being set we return here, without
// an error, because technically it isn't an error. So return
// without parsing the potentially corrupt packet and hitting an error.
// TODO(miek): this isn't the best strategy!
// Better stragey would be: set boolean indicating truncated message, go forth and parse
// until we hit an error, return the message without the latest parsed rr if this boolean
// is true.
if dns.Truncated {
dns.Answer = nil
dns.Ns = nil
dns.Extra = nil
return nil
}

var r RR
for i := 0; i < int(dh.Ancount); i++ {
off1 := off
r, off, err = UnpackRR(msg, off)
if err != nil {
return err
}
if off1 == off { // Offset does not increase anymore, dh.Ancount is a lie!
dh.Ancount = uint16(i)
break
}
dns.Answer = append(dns.Answer, r)
}
for i := 0; i < int(dh.Nscount); i++ {
off1 := off
r, off, err = UnpackRR(msg, off)
if err != nil {
return err
}
if off1 == off { // Offset does not increase anymore, dh.Nscount is a lie!
dh.Nscount = uint16(i)
break
}
dns.Ns = append(dns.Ns, r)
dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
// The header counts might have been wrong so we need to update it
dh.Ancount = uint16(len(dns.Answer))
if err == nil {
dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
}
for i := 0; i < int(dh.Arcount); i++ {
off1 := off
r, off, err = UnpackRR(msg, off)
if err != nil {
return err
}
if off1 == off { // Offset does not increase anymore, dh.Arcount is a lie!
dh.Arcount = uint16(i)
break
}
dns.Extra = append(dns.Extra, r)
// The header counts might have been wrong so we need to update it
dh.Nscount = uint16(len(dns.Ns))
if err == nil {
dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
}
// The header counts might have been wrong so we need to update it
dh.Arcount = uint16(len(dns.Extra))
if off != len(msg) {
// TODO(miek) make this an error?
// use PackOpt to let people tell how detailed the error reporting should be?
// println("dns: extra bytes in dns packet", off, "<", len(msg))
} else if dns.Truncated {
// Whether we ran into a an error or not, we want to return that it
// was truncated
err = ErrTruncated
}
return nil
return err
}

// Convert a complete message to a string with dig-like output.
Expand Down

0 comments on commit d274557

Please sign in to comment.