-
Notifications
You must be signed in to change notification settings - Fork 3
/
proxy.go
230 lines (204 loc) · 5.11 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
// Package dnsproxy contains the DNS proxy.
package dnsproxy
import (
"math/rand"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// Proxy is a DNS proxy. It forwards client requests to remote DNS servers
// and emits a Report for A and AAAA answers that belong to watched names.
type Proxy struct {
// config is the configuration the proxy was created with.
config *Config
// udp is the UDP server; NewProxy leaves it nil when config.ListenUDP is
// not set.
udp *dns.Server
// tcp is the TCP server; NewProxy leaves it nil when config.ListenTCP is
// not set.
tcp *dns.Server
// remotes is queried with the question name of a request to get the
// remote servers the request may be forwarded to.
remotes *Remotes
// watches holds the watched names; CNAME and DNAME targets of watched
// names are added to it as temporary watches.
watches *Watches
// reports is the unbuffered channel on which request handlers send
// Reports; it is closed when the proxy has stopped.
reports chan *Report
// done is closed by Stop to request termination of the proxy.
done chan struct{}
// closed is closed once the proxy has finished shutting down.
closed chan struct{}
}
// handleRequest handles a single DNS client request. It forwards the request
// to a randomly chosen remote server configured for the question name,
// inspects the answers in the remote server's reply and finally relays the
// reply to the client:
//
//   - A and AAAA answers for watched names are sent as a Report on the
//     reports channel; the handler waits for each report before continuing.
//   - CNAME and DNAME answers for watched names add their target as a
//     temporary watch with the record's TTL.
//
// If the request does not contain exactly one question, the client receives
// a FORMERR reply. If no remote is configured for the question name or the
// exchange with the remote fails, the client receives a SERVFAIL reply, so it
// does not have to wait for its own timeout.
//
// If the proxy is stopped while a report is pending, the handler returns
// without sending the reply instead of blocking on the reports channel.
func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
	// replyError sends an answerless reply with rcode to the client.
	replyError := func(rcode int) {
		m := new(dns.Msg)
		m.SetRcode(r, rcode)
		if err := w.WriteMsg(m); err != nil {
			log.WithError(err).Error("DNS-Proxy could not send error reply")
		}
	}

	// sendReport delivers report to the consumer of the reports channel and
	// waits for it. It returns false if the proxy is stopped before the
	// report could be delivered; without this check a handler could block
	// on the unbuffered channel forever once nobody reads it any more.
	sendReport := func(report *Report) bool {
		select {
		case p.reports <- report:
			report.Wait()
			return true
		case <-p.done:
			return false
		}
	}

	// make sure the client request is valid
	if len(r.Question) != 1 {
		log.WithField("request", r).Error("DNS-Proxy received invalid client request")
		replyError(dns.RcodeFormatError)
		return
	}
	question := r.Question[0].Name

	// find remote servers for the question name
	remotes := p.remotes.Get(question)
	if len(remotes) == 0 {
		log.WithField("name", question).
			Error("DNS-Proxy has no remotes for question name")
		replyError(dns.RcodeServerFailure)
		return
	}

	// pick random remote server, forward request and get reply
	// TODO: query all servers and take fastest reply?
	remote := remotes[rand.Intn(len(remotes))]
	reply, err := dns.Exchange(r, remote)
	if err != nil {
		log.WithError(err).Debug("DNS-Proxy DNS exchange error")
		replyError(dns.RcodeServerFailure)
		return
	}

	// parse answers in reply from remote server
	for _, a := range reply.Answer {
		name := a.Header().Name

		// The watch list is checked for every answer because temporary
		// watches added for earlier CNAME/DNAME answers must be visible
		// to the answers that follow them.
		if !p.watches.Contains(question) && !p.watches.Contains(name) {
			// not on watch list, ignore answer
			continue
		}

		ttl := a.Header().Ttl

		switch a.Header().Rrtype {
		case dns.TypeA:
			// A Record, report IPv4 address
			rr, ok := a.(*dns.A)
			if !ok {
				log.Error("DNS-Proxy received invalid A record in reply")
				continue
			}
			if !sendReport(NewReport(name, rr.A, ttl)) {
				log.Debug("DNS-Proxy stopping, dropping request")
				return
			}

		case dns.TypeAAAA:
			// AAAA Record, report IPv6 address
			rr, ok := a.(*dns.AAAA)
			if !ok {
				log.Error("DNS-Proxy received invalid AAAA record in reply")
				continue
			}
			if !sendReport(NewReport(name, rr.AAAA, ttl)) {
				log.Debug("DNS-Proxy stopping, dropping request")
				return
			}

		case dns.TypeCNAME:
			// CNAME record, store temporary watch
			rr, ok := a.(*dns.CNAME)
			if !ok {
				log.Error("DNS-Proxy received invalid CNAME record in reply")
				continue
			}
			log.WithFields(log.Fields{
				"target": rr.Target,
				"ttl":    ttl,
			}).Debug("DNS-Proxy received CNAME in reply")
			p.watches.AddTemp(rr.Target, ttl)

		case dns.TypeDNAME:
			// DNAME record, store temporary watch
			rr, ok := a.(*dns.DNAME)
			if !ok {
				log.Error("DNS-Proxy received invalid DNAME record in reply")
				continue
			}
			log.WithFields(log.Fields{
				"target": rr.Target,
				"ttl":    ttl,
			}).Debug("DNS-Proxy received DNAME in reply")
			p.watches.AddTemp(rr.Target, ttl)
		}
	}

	// send reply to client
	if err := w.WriteMsg(reply); err != nil {
		log.WithError(err).Error("DNS-Proxy could not forward reply")
	}
}
// startDNSServer runs server and logs the error it stopped with, if any.
// A nil server is ignored. start calls it in a separate goroutine for each
// server.
func (p *Proxy) startDNSServer(server *dns.Server) {
	if server != nil {
		fields := log.Fields{
			"addr": server.Addr,
			"net":  server.Net,
		}
		log.WithFields(fields).Debug("DNS-Proxy starting server")

		if err := server.ListenAndServe(); err != nil {
			log.WithError(err).Error("DNS-Proxy DNS server stopped")
		}
	}
}
// stopDNSServer shuts server down and logs an error, together with the
// server's address and network, if the shutdown fails. A nil server is
// ignored.
func (p *Proxy) stopDNSServer(server *dns.Server) {
	if server == nil {
		return
	}

	if err := server.Shutdown(); err != nil {
		fields := log.Fields{
			"addr":  server.Addr,
			"net":   server.Net,
			"error": err,
		}
		log.WithFields(fields).Error("DNS-Proxy could not stop DNS server")
	}
}
// start runs the proxy: it registers the request handler for all names,
// starts the configured DNS servers in their own goroutines and blocks until
// the done channel is closed. It then stops the servers, closes the watches,
// closes the reports channel and finally closes the closed channel to signal
// that the shutdown is complete.
func (p *Proxy) start() {
	// deferred calls run in reverse order: watches, reports, closed
	defer close(p.closed)
	defer close(p.reports)
	defer p.watches.Close()

	servers := []*dns.Server{p.udp, p.tcp}

	// start dns servers
	log.Debug("DNS-Proxy registering handler")
	dns.HandleFunc(".", p.handleRequest)
	for i := range servers {
		go p.startDNSServer(servers[i])
	}

	// block until Stop requests termination
	<-p.done

	// stop dns servers
	for i := range servers {
		p.stopDNSServer(servers[i])
	}
}
// Start launches the proxy in a background goroutine and returns
// immediately. Use Stop to terminate the proxy.
func (p *Proxy) Start() {
	go p.start()
}
// Stop terminates the proxy and blocks until its shutdown has completed.
func (p *Proxy) Stop() {
	// request termination of the goroutine launched by Start ...
	close(p.done)

	// ... and wait until it signals that it is finished
	<-p.closed
}
// Reports returns the channel on which the proxy sends a Report for A and
// AAAA records of watched domains. The channel is closed when the proxy has
// stopped.
func (p *Proxy) Reports() chan *Report {
	return p.reports
}
// SetRemotes replaces the current mapping from domain names to remote server
// addresses with remotes: the existing mapping is flushed first and then
// every entry of remotes is added.
func (p *Proxy) SetRemotes(remotes map[string][]string) {
	p.remotes.Flush()
	for domain, servers := range remotes {
		p.remotes.Add(domain, servers)
	}
}
// SetWatches replaces the domains watched for A and AAAA record updates with
// watches: the existing watches are flushed first and then every domain in
// watches is added.
func (p *Proxy) SetWatches(watches []string) {
	p.watches.Flush()
	for i := range watches {
		p.watches.Add(watches[i])
	}
}
// NewProxy returns a new Proxy for config. A UDP server is created if
// config.ListenUDP is set and a TCP server if config.ListenTCP is set, both
// on config.Address. The servers are not started until Start is called.
func NewProxy(config *Config) *Proxy {
	// newServer returns a server for network on the configured address
	newServer := func(network string) *dns.Server {
		return &dns.Server{
			Addr: config.Address,
			Net:  network,
		}
	}

	// servers stay nil for disabled listeners
	var udp, tcp *dns.Server
	if config.ListenUDP {
		udp = newServer("udp")
	}
	if config.ListenTCP {
		tcp = newServer("tcp")
	}

	return &Proxy{
		config:  config,
		udp:     udp,
		tcp:     tcp,
		remotes: NewRemotes(),
		watches: NewWatches(),
		reports: make(chan *Report),
		done:    make(chan struct{}),
		closed:  make(chan struct{}),
	}
}