-
Notifications
You must be signed in to change notification settings - Fork 63
/
dnslistener.go
131 lines (112 loc) · 3.07 KB
/
dnslistener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package rdns
import (
"crypto/tls"
"net"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// DNSListener is a standard DNS listener for UDP or TCP. It embeds a
// *dns.Server, so the server's fields (Addr, Net, Handler, ...) and its
// methods are promoted onto the listener.
type DNSListener struct {
	*dns.Server
	id string // identifier used in log fields and returned by String
}

// Compile-time assertion that *DNSListener satisfies Listener.
var _ Listener = (*DNSListener)(nil)
// ListenOptions holds optional settings for a listener, such as which
// clients are permitted to send queries.
type ListenOptions struct {
	// AllowedNet lists the client networks allowed to query this listener.
	// When empty, no client-address restriction is applied (see isAllowed).
	AllowedNet []*net.IPNet
}
// NewDNSListener builds a DNSListener identified by id that listens on addr
// using the given network (either UDP or TCP). Every query it receives is
// checked against opt.AllowedNet and, when permitted, forwarded to resolver.
func NewDNSListener(id, addr, net string, opt ListenOptions, resolver Resolver) *DNSListener {
	handler := listenHandler(id, net, addr, resolver, opt.AllowedNet)
	srv := &dns.Server{Addr: addr, Net: net, Handler: handler}

	l := new(DNSListener)
	l.id = id
	l.Server = srv
	return l
}
// Start logs the listener's id, protocol and address, then runs the embedded
// dns.Server's ListenAndServe and returns whatever error it reports.
func (s DNSListener) Start() error {
	fields := logrus.Fields{
		"id":       s.id,
		"protocol": s.Net,
		"addr":     s.Addr,
	}
	Log.WithFields(fields).Info("starting listener")
	return s.Server.ListenAndServe()
}
// String returns the listener's identifier (the id given to NewDNSListener).
func (s DNSListener) String() string {
	return s.id
}
// listenHandler returns a dns.HandlerFunc that forwards incoming queries to
// the resolver r on behalf of the listener identified by id.
//
// For every query the handler:
//   - fills a ClientInfo with the listener id, the client's source IP and,
//     when the ResponseWriter exposes a non-nil TLS connection state, the
//     TLS server name requested by the client;
//   - answers REFUSED if the source IP is not within allowedNet (an empty
//     allowedNet allows all clients);
//   - answers SERVFAIL if the resolver returns an error;
//   - closes the connection without answering if the resolver returns a nil
//     message;
//   - pads the answer for "dot" and "dtls" and strips padding otherwise;
//   - truncates the answer for "udp" and "dtls" to the size advertised in the
//     query's EDNS0 OPT record, or dns.MinMsgSize when there is none.
//
// protocol and addr are included in log fields; protocol also selects the
// padding and truncation behavior described above.
func listenHandler(id, protocol, addr string, r Resolver, allowedNet []*net.IPNet) dns.HandlerFunc {
	metrics := NewListenerMetrics("listener", id)
	return func(w dns.ResponseWriter, req *dns.Msg) {
		var err error
		ci := ClientInfo{
			Listener: id,
		}

		// Some ResponseWriters (presumably those on TLS transports) expose the
		// TLS connection state; record the server name the client requested.
		if tlsWriter, ok := w.(interface{ ConnectionState() *tls.ConnectionState }); ok {
			if connState := tlsWriter.ConnectionState(); connState != nil {
				ci.TLSServerName = connState.ServerName
			}
		}

		// Named remote (not addr) so it does not shadow the listener address
		// that is logged below.
		switch remote := w.RemoteAddr().(type) {
		case *net.TCPAddr:
			ci.SourceIP = remote.IP
		case *net.UDPAddr:
			ci.SourceIP = remote.IP
		}

		log := Log.WithFields(logrus.Fields{"id": id, "client": ci.SourceIP, "qname": qName(req), "protocol": protocol, "addr": addr})
		log.Debug("received query")
		metrics.query.Add(1)

		a := new(dns.Msg)
		if isAllowed(allowedNet, ci.SourceIP) {
			log.WithField("resolver", r.String()).Trace("forwarding query to resolver")
			a, err = r.Resolve(req, ci)
			if err != nil {
				metrics.err.Add("resolve", 1)
				log.WithError(err).Error("failed to resolve")
				a = servfail(req)
			}
		} else {
			metrics.err.Add("acl", 1)
			log.Debug("refusing client ip")
			a.SetRcode(req, dns.RcodeRefused)
		}

		// A nil response from the resolvers means "drop": close the connection
		// without answering. The Close error is deliberately ignored since
		// nothing further is sent to this client.
		if a == nil {
			_ = w.Close()
			metrics.drop.Add(1)
			return
		}

		// If the client asked via DoT/DTLS and EDNS0 is enabled, the response
		// should be padded for extra security. See RFC 7830 and RFC 8467.
		if protocol == "dot" || protocol == "dtls" {
			padAnswer(req, a)
		} else {
			stripPadding(a)
		}

		// Check the response actually fits if the query was sent over a
		// datagram transport. If not, respond with the TC flag.
		if protocol == "udp" || protocol == "dtls" {
			maxSize := dns.MinMsgSize
			if edns0 := req.IsEdns0(); edns0 != nil {
				maxSize = int(edns0.UDPSize())
			}
			a.Truncate(maxSize)
		}

		metrics.response.Add(rCode(a), 1)
		// A failed write (for instance a client that already went away)
		// cannot be reported back to the client; log it so it is visible
		// when debugging instead of discarding it silently.
		if err = w.WriteMsg(a); err != nil {
			log.WithError(err).Debug("failed to write response")
		}
	}
}
// isAllowed reports whether ip may query a listener restricted to allowedNet.
// An empty (or nil) allowedNet imposes no restriction and allows every ip;
// otherwise ip must be contained in at least one of the listed networks.
func isAllowed(allowedNet []*net.IPNet, ip net.IP) bool {
	if len(allowedNet) == 0 {
		return true
	}
	// Named ipNet rather than net so the loop variable does not shadow the
	// imported net package.
	for _, ipNet := range allowedNet {
		if ipNet.Contains(ip) {
			return true
		}
	}
	return false
}