Skip to content

Commit

Permalink
Added UDPSize properties to client and server
Browse files Browse the repository at this point in the history
  • Loading branch information
ameshkov committed Jul 13, 2021
1 parent b5bcf75 commit 3416ca5
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 15 deletions.
32 changes: 29 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
type Client struct {
Net string // protocol (can be "udp" or "tcp", by default - "udp")
Timeout time.Duration // read/write timeout

// UDPSize is the maximum size of a DNS response (or query) this client can
// sent or receive. If not set, we use dns.MinMsgSize by default.
UDPSize int
}

// ResolverInfo contains DNSCrypt resolver information necessary for decryption/encryption
Expand Down Expand Up @@ -158,7 +162,11 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
}

if proto == "udp" {
response := make([]byte, maxQueryLen)
bufSize := c.UDPSize
if bufSize == 0 {
bufSize = dns.MinMsgSize
}
response := make([]byte, bufSize)
n, err := conn.Read(response)
if err != nil {
return nil, err
Expand All @@ -182,7 +190,12 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error)
if err != nil {
return nil, err
}
return q.Encrypt(query, resolverInfo.SharedKey)
b, err := q.Encrypt(query, resolverInfo.SharedKey)
if len(b) > c.maxQuerySize() {
return nil, ErrQueryTooLarge
}

return b, err
}

// decrypts decrypts a DNS message using a shared key from the resolver info
Expand Down Expand Up @@ -212,7 +225,8 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {

query := new(dns.Msg)
query.SetQuestion(providerName, dns.TypeTXT)
client := dns.Client{Net: c.Net, UDPSize: uint16(maxQueryLen), Timeout: c.Timeout}
// use 1252 as a UDPSize for this client to make sure the buffer is not too small
client := dns.Client{Net: c.Net, UDPSize: uint16(1252), Timeout: c.Timeout}
r, _, err := client.Exchange(query, stamp.ServerAddrStr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -284,3 +298,15 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {

return nil, certErr
}

func (c *Client) maxQuerySize() int {
if c.Net == "tcp" {
return dns.MaxMsgSize
}

if c.UDPSize > 0 {
return c.UDPSize
}

return dns.MinMsgSize
}
3 changes: 0 additions & 3 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ const (
// Some servers do not work if padded length is less than 256. Example: Quad9
minUDPQuestionSize = 256

// <max-query-len> is the maximum allowed query length
maxQueryLen = 1252

// Minimum possible DNS packet size
minDNSPacketSize = 12 + 5

Expand Down
4 changes: 0 additions & 4 deletions encrypted_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) (
return nil, ErrEsVersion
}

if len(query) > maxQueryLen {
return nil, ErrQueryTooLarge
}

return query, nil
}

Expand Down
8 changes: 8 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ type Server struct {
// ResolverCert contains resolver certificate.
ResolverCert *Cert

// UDPSize is the default buffer size to use to read incoming UDP messages.
// If not set it defaults to dns.MinMsgSize (512 B).
UDPSize int

// Handler to invoke. If nil, uses DefaultHandler.
Handler Handler

Expand Down Expand Up @@ -148,6 +152,10 @@ func (s *Server) init() {
s.tcpConns = map[net.Conn]struct{}{}
s.udpListeners = map[*net.UDPConn]struct{}{}
s.tcpListeners = map[net.Listener]struct{}{}

if s.UDPSize == 0 {
s.UDPSize = dns.MinMsgSize
}
}

// isStarted returns true if the server is processing queries right now
Expand Down
2 changes: 1 addition & 1 deletion server_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (w *TCPResponseWriter) RemoteAddr() net.Addr {

// WriteMsg writes DNS message to the client
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
m.Truncate(dnsSize("tcp", w.req))
normalize("tcp", w.req, m)

res, err := w.encrypt(m, w.query)
if err != nil {
Expand Down
108 changes: 107 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,85 @@ func TestServer_ReadTimeout(t *testing.T) {
testThisServerRespondMessages(t, "tcp", srv)
}

func TestServer_UDPTruncateMessage(t *testing.T) {
// Create a test server that returns large response which should be
// truncated if sent over UDP
srv := newTestServer(t, &testLargeMsgHandler{})
t.Cleanup(func() {
require.NoError(t, srv.Close())
})

// Create client and connect
client := &Client{
Timeout: 1 * time.Second,
Net: "udp",
}
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
stamp := dnsstamps.ServerStamp{
ServerAddrStr: serverAddr,
ServerPk: srv.resolverPk,
ProviderName: srv.server.ProviderName,
Proto: dnsstamps.StampProtoTypeDNSCrypt,
}
ri, err := client.DialStamp(stamp)
require.NoError(t, err)
require.NotNil(t, ri)

// Send a test message and check that the response was truncated
m := createTestMessage()
res, err := client.Exchange(m, ri)
require.NoError(t, err)
require.NotNil(t, res)
require.Equal(t, dns.RcodeSuccess, res.Rcode)
require.Len(t, res.Answer, 0)
require.True(t, res.Truncated)
}

func TestServer_UDPEDNS0_NoTruncate(t *testing.T) {
// Create a test server that returns large response which should be
// truncated if sent over UDP
// However, when EDNS0 is set with the buffer large enough, there should
// be no truncation
srv := newTestServer(t, &testLargeMsgHandler{})
t.Cleanup(func() {
require.NoError(t, srv.Close())
})

// Create client and connect
client := &Client{
Timeout: 1 * time.Second,
Net: "udp",
UDPSize: 7000, // make sure the client will be able to read the response
}
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
stamp := dnsstamps.ServerStamp{
ServerAddrStr: serverAddr,
ServerPk: srv.resolverPk,
ProviderName: srv.server.ProviderName,
Proto: dnsstamps.StampProtoTypeDNSCrypt,
}
ri, err := client.DialStamp(stamp)
require.NoError(t, err)
require.NotNil(t, ri)

// Send a test message with UDP buffer size large enough
// and check that the response was NOT truncated
m := createTestMessage()
m.Extra = append(m.Extra, &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
Class: 2000, // Set large enough UDPSize here
},
})
res, err := client.Exchange(m, ri)
require.NoError(t, err)
require.NotNil(t, res)
require.Equal(t, dns.RcodeSuccess, res.Rcode)
require.Len(t, res.Answer, 64)
require.False(t, res.Truncated)
}

func testServerServeCert(t *testing.T, network string) {
srv := newTestServer(t, &testHandler{})
t.Cleanup(func() {
Expand Down Expand Up @@ -193,17 +272,44 @@ type testHandler struct{}

// ServeDNS - implements Handler interface
func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
// Google DNS
res := new(dns.Msg)
res.SetReply(r)

answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Ttl: 300,
Class: dns.ClassINET,
}
// First record is from Google DNS
answer.A = net.IPv4(8, 8, 8, 8)
res.Answer = append(res.Answer, answer)

return rw.WriteMsg(res)
}

// testLargeMsgHandler is a handler that returns a huge response
// used for testing messages truncation
type testLargeMsgHandler struct{}

// ServeDNS - implements Handler interface
func (h *testLargeMsgHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
res := new(dns.Msg)
res.SetReply(r)

for i := 0; i < 64; i++ {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Ttl: 300,
Class: dns.ClassINET,
}
answer.A = net.IPv4(127, 0, 0, byte(i))
res.Answer = append(res.Answer, answer)
}

res.Compress = true
return rw.WriteMsg(res)
}
6 changes: 3 additions & 3 deletions server_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error)
type UDPResponseWriter struct {
udpConn *net.UDPConn // UDP connection
sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP)
encrypt encryptionFunc // DNSCRypt encryption function
encrypt encryptionFunc // DNSCrypt encryption function
req *dns.Msg // DNS query that was processed
query EncryptedQuery // DNSCrypt query properties
}
Expand All @@ -40,7 +40,7 @@ func (w *UDPResponseWriter) RemoteAddr() net.Addr {

// WriteMsg writes DNS message to the client
func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
m.Truncate(dnsSize("udp", w.req))
normalize("udp", w.req, m)

res, err := w.encrypt(m, w.query)
if err != nil {
Expand Down Expand Up @@ -157,7 +157,7 @@ func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) {
// readUDPMsg reads incoming UDP message
func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) {
_ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout))
b := make([]byte, dns.MinMsgSize)
b := make([]byte, s.UDPSize)
n, sess, err := dns.ReadFromSessionUDP(l, b)
if err != nil {
return nil, nil, err
Expand Down
19 changes: 19 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,25 @@ func unpackTxtString(s string) ([]byte, error) {
return msg, nil
}

// normalize truncates the DNS response if needed depending on the protocol
func normalize(proto string, req *dns.Msg, res *dns.Msg) {
size := dnsSize(proto, req)
// DNSCrypt encryption adds a header to each message, we should
// consider this when truncating a message.
// 64 should cover all cases
size = size - 64

// Truncate response message
res.Truncate(size)

// In case of UDP it is safer to simply remove all response records
// dns.Msg.Truncate method will not consider that we need a response
// shorter than dns.MinMsgSize
if res.Truncated && proto == "udp" {
res.Answer = nil
}
}

// dnsSize returns if buffer size *advertised* in the requests OPT record.
// Or when the request was over TCP, we return the maximum allowed size of 64K.
func dnsSize(proto string, r *dns.Msg) int {
Expand Down

0 comments on commit 3416ca5

Please sign in to comment.