From d4e356613d2615d38963f33c3c27029f7a52cef8 Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Fri, 31 May 2024 17:36:53 -0400 Subject: [PATCH 1/2] Add a hook to catch invalid messages Currently there are hooks for reading messages off the wire (DecorateReader), checking if they comply with policy (MsgAcceptFunc), and generating responses (Handler). However, there is no hook that notifies the server when a message is dropped or rejected due to a syntax error. That makes it hard to monitor these packets without repeating the parsing process. This PR adds a hook for notifications about invalid packets. --- acceptfunc_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++ server.go | 19 ++++++++++- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/acceptfunc_test.go b/acceptfunc_test.go index d40d4e4cd..4bcbd12ad 100644 --- a/acceptfunc_test.go +++ b/acceptfunc_test.go @@ -1,6 +1,8 @@ package dns import ( + "encoding/binary" + "net" "testing" ) @@ -33,3 +35,86 @@ func handleNotify(w ResponseWriter, req *Msg) { m.SetReply(req) w.WriteMsg(m) } + +func TestInvalidMsg(t *testing.T) { + HandleFunc("example.org.", func(ResponseWriter, *Msg) { + t.Fatal("the handler must not be called in any of these tests") + }) + s, addrstr, _, err := RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + s.MsgAcceptFunc = func(dh Header) MsgAcceptAction { + switch dh.Id { + case 0x0001: + return MsgAccept + case 0x0002: + return MsgReject + case 0x0003: + return MsgIgnore + case 0x0004: + return MsgRejectNotImplemented + default: + t.Errorf("unexpected ID %x", dh.Id) + return -1 + } + } + + invalidErrors := make(chan error) + s.InvalidMsgFunc = func(m []byte, err error) { + invalidErrors <- err + } + + c, err := net.Dial("tcp", addrstr) + if err != nil { + t.Fatalf("cannot connect to test server: %v", err) + } + + write := func(m []byte) { + var length [2]byte + binary.BigEndian.PutUint16(length[:], uint16(len(m))) + _, err := c.Write(length[:]) + if err != nil { + t.Fatalf("length write failed: %v", err) + } + _, err = c.Write(m) + if err != nil { + t.Fatalf("content write failed: %v", err) + } + } + + /* Message is too short, so there is no header to accept or reject. */ + + tooShortMessage := make([]byte, 11) + tooShortMessage[1] = 0x3 // ID = 3, would be ignored if it were parsable. + + write(tooShortMessage) + // Expect an error to be reported. + <-invalidErrors + + /* Message is accepted but is actually invalid. */ + + badMessage := make([]byte, 13) + badMessage[1] = 0x1 // ID = 1, Accept. + badMessage[5] = 1 // QDCOUNT = 1 + badMessage[12] = 99 // Bad question section. Invalid! + + write(badMessage) + // Expect an error to be reported. + <-invalidErrors + + /* Message is rejected before it can be determined to be invalid. */ + + close(invalidErrors) // A call to InvalidMsgFunc would panic due to the closed chan. + + badMessage[1] = 0x2 // ID = 2, Reject + write(badMessage) + + badMessage[1] = 0x3 // ID = 3, Ignore + write(badMessage) + + badMessage[1] = 0x4 // ID = 4, RejectNotImplemented + write(badMessage) +} diff --git a/server.go b/server.go index 0207d6da2..e95e4da78 100644 --- a/server.go +++ b/server.go @@ -188,6 +188,14 @@ type DecorateReader func(Reader) Reader // Implementations should never return a nil Writer. type DecorateWriter func(Writer) Writer +// InvalidMsgFunc is a listener hook for observing incoming messages that were discarded +// because they could not be parsed. +// Every message that is read by a Reader will eventually be provided to the Handler, +// rejected (or ignored) by the MsgAcceptFunc, or passed to this function. +type InvalidMsgFunc func(m []byte, err error) + +func DefaultInvalidMsgFunc(m []byte, err error) {} + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -233,6 +241,8 @@ type Server struct { // AcceptMsgFunc will check the incoming message and will reject it early in the process. // By default DefaultMsgAcceptFunc will be used. MsgAcceptFunc MsgAcceptFunc + // InvalidMsgFunc is optional, will be called if a message is received but cannot be parsed. + InvalidMsgFunc InvalidMsgFunc // Shutdown handling lock sync.RWMutex @@ -277,6 +287,9 @@ func (srv *Server) init() { if srv.MsgAcceptFunc == nil { srv.MsgAcceptFunc = DefaultMsgAcceptFunc } + if srv.InvalidMsgFunc == nil { + srv.InvalidMsgFunc = DefaultInvalidMsgFunc + } if srv.Handler == nil { srv.Handler = DefaultServeMux } @@ -531,6 +544,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error { if cap(m) == srv.UDPSize { srv.udpPool.Put(m[:srv.UDPSize]) } + srv.InvalidMsgFunc(m, ErrShortRead) continue } wg.Add(1) @@ -611,6 +625,7 @@ func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn func (srv *Server) serveDNS(m []byte, w *response) { dh, off, err := unpackMsgHdr(m, 0) if err != nil { + srv.InvalidMsgFunc(m, err) // Let client hang, they are sending crap; any reply can be used to amplify. return } @@ -620,10 +635,12 @@ func (srv *Server) serveDNS(m []byte, w *response) { switch action := srv.MsgAcceptFunc(dh); action { case MsgAccept: - if req.unpack(dh, m, off) == nil { + err := req.unpack(dh, m, off) + if err == nil { break } + srv.InvalidMsgFunc(m, err) fallthrough case MsgReject, MsgRejectNotImplemented: opcode := req.Opcode From caf25b32f2ccab0ab3107b21e4ac660c0ebb49fe Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Thu, 13 Jun 2024 11:07:16 -0400 Subject: [PATCH 2/2] s/InvalidMsg/MsgInvalid/g --- acceptfunc_test.go | 2 +- server.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/acceptfunc_test.go b/acceptfunc_test.go index 4bcbd12ad..868154d40 100644 --- a/acceptfunc_test.go +++ b/acceptfunc_test.go @@ -63,7 +63,7 @@ func TestInvalidMsg(t *testing.T) { } invalidErrors := make(chan error) - s.InvalidMsgFunc = func(m []byte, err error) { + s.MsgInvalidFunc = func(m []byte, err error) { invalidErrors <- err } diff --git a/server.go b/server.go index e95e4da78..2f7655645 100644 --- a/server.go +++ b/server.go @@ -194,7 +194,7 @@ type DecorateWriter func(Writer) Writer // rejected (or ignored) by the MsgAcceptFunc, or passed to this function. type InvalidMsgFunc func(m []byte, err error) -func DefaultInvalidMsgFunc(m []byte, err error) {} +func DefaultMsgInvalidFunc(m []byte, err error) {} // A Server defines parameters for running an DNS server. type Server struct { @@ -241,8 +241,8 @@ type Server struct { // AcceptMsgFunc will check the incoming message and will reject it early in the process. // By default DefaultMsgAcceptFunc will be used. MsgAcceptFunc MsgAcceptFunc - // InvalidMsgFunc is optional, will be called if a message is received but cannot be parsed. - InvalidMsgFunc InvalidMsgFunc + // MsgInvalidFunc is optional, will be called if a message is received but cannot be parsed. + MsgInvalidFunc InvalidMsgFunc // Shutdown handling lock sync.RWMutex @@ -287,8 +287,8 @@ func (srv *Server) init() { if srv.MsgAcceptFunc == nil { srv.MsgAcceptFunc = DefaultMsgAcceptFunc } - if srv.InvalidMsgFunc == nil { - srv.InvalidMsgFunc = DefaultInvalidMsgFunc + if srv.MsgInvalidFunc == nil { + srv.MsgInvalidFunc = DefaultMsgInvalidFunc } if srv.Handler == nil { srv.Handler = DefaultServeMux @@ -544,7 +544,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error { if cap(m) == srv.UDPSize { srv.udpPool.Put(m[:srv.UDPSize]) } - srv.InvalidMsgFunc(m, ErrShortRead) + srv.MsgInvalidFunc(m, ErrShortRead) continue } wg.Add(1) @@ -625,7 +625,7 @@ func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn func (srv *Server) serveDNS(m []byte, w *response) { dh, off, err := unpackMsgHdr(m, 0) if err != nil { - srv.InvalidMsgFunc(m, err) + srv.MsgInvalidFunc(m, err) // Let client hang, they are sending crap; any reply can be used to amplify. return } @@ -640,7 +640,7 @@ func (srv *Server) serveDNS(m []byte, w *response) { break } - srv.InvalidMsgFunc(m, err) + srv.MsgInvalidFunc(m, err) fallthrough case MsgReject, MsgRejectNotImplemented: opcode := req.Opcode