-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdns.go
130 lines (107 loc) · 2.89 KB
/
dns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package warc
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
// cachedIP is the value stored in the dialer's DNS cache for a resolved host.
type cachedIP struct {
// expiresAt is the instant after which the entry is treated as stale;
// archiveDNS compares it against time.Now() before reusing the entry.
expiresAt time.Time
// ip is the address that was selected when the host was resolved.
ip net.IP
}
// maxFallbackDNSServers caps how many additional configured DNS servers
// archiveDNS may try after the first one, so at most
// maxFallbackDNSServers+1 servers are queried for a single resolution.
const maxFallbackDNSServers = 3
// archiveDNS resolves the host part of address to a single IP address.
//
// address must be in "host:port" form: it is split with net.SplitHostPort,
// which returns an error for a bare host. If the host is already an IP literal
// it is returned as is. Otherwise the dialer's DNS cache is consulted and, on a
// miss or an expired entry, A and AAAA lookups are issued concurrently through
// d.lookupIP for every address family that is not disabled.
//
// The first configured DNS server is tried first; if every enabled lookup
// fails, up to maxFallbackDNSServers further servers are tried in order. When
// both families resolve, the IPv6 address is preferred. A successful result is
// cached for d.DNSRecordsTTL.
//
// It returns an error when address cannot be split, when no DNS server is
// configured, when every enabled lookup failed on every server tried (the
// lookup errors are wrapped), or when no usable address was obtained.
func (d *customDialer) archiveDNS(ctx context.Context, address string) (resolvedIP net.IP, err error) {
	// Strip the port; a missing port is reported as an error by SplitHostPort.
	address, _, err = net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}

	// Check if the address is already an IP
	if resolvedIP = net.ParseIP(address); resolvedIP != nil {
		return resolvedIP, nil
	}

	// Check cache first
	if cached, ok := d.DNSRecords.Load(address); ok {
		cachedEntry := cached.(cachedIP)
		if time.Now().Before(cachedEntry.expiresAt) {
			return cachedEntry.ip, nil
		}
		// Cache entry expired, remove it
		d.DNSRecords.Delete(address)
	}

	if len(d.DNSConfig.Servers) == 0 {
		return nil, fmt.Errorf("no DNS servers configured")
	}

	// With both families disabled no lookup can ever run, so there is nothing
	// to wait for and no server worth trying.
	if d.disableIPv4 && d.disableIPv6 {
		return nil, fmt.Errorf("no suitable IP address found for %s", address)
	}

	var (
		wg            sync.WaitGroup
		ipv4, ipv6    net.IP
		errA, errAAAA error
	)

	fallbackServers := min(maxFallbackDNSServers, len(d.DNSConfig.Servers)-1)
	for DNSServer := 0; DNSServer <= fallbackServers; DNSServer++ {
		// The server index is passed by value so each goroutine is independent
		// of the loop variable's scoping rules.
		if !d.disableIPv4 {
			wg.Add(1)
			go func(server int) {
				defer wg.Done()
				ipv4, errA = d.lookupIP(ctx, address, dns.TypeA, server)
			}(DNSServer)
		}

		if !d.disableIPv6 {
			wg.Add(1)
			go func(server int) {
				defer wg.Done()
				ipv6, errAAAA = d.lookupIP(ctx, address, dns.TypeAAAA, server)
			}(DNSServer)
		}

		wg.Wait()

		// Only a lookup that actually ran counts as a success. A disabled
		// family leaves its error nil, which must not be mistaken for a
		// successful answer, otherwise the fallback servers are never tried.
		if (!d.disableIPv4 && errA == nil) || (!d.disableIPv6 && errAAAA == nil) {
			break
		}

		// Further servers cannot succeed once the context is done.
		if ctx.Err() != nil {
			break
		}
	}

	// Prioritize IPv6 if both are available and enabled
	switch {
	case ipv6 != nil && !d.disableIPv6:
		resolvedIP = ipv6
	case ipv4 != nil && !d.disableIPv4:
		resolvedIP = ipv4
	}

	if resolvedIP != nil {
		// Cache the result
		d.DNSRecords.Store(address, cachedIP{
			ip:        resolvedIP,
			expiresAt: time.Now().Add(d.DNSRecordsTTL),
		})
		return resolvedIP, nil
	}

	// Nothing usable was resolved: surface the lookup failure(s) that caused it.
	switch {
	case errA != nil && errAAAA != nil:
		return nil, fmt.Errorf("failed to resolve DNS: A error: %w, AAAA error: %w", errA, errAAAA)
	case errA != nil:
		return nil, fmt.Errorf("failed to resolve DNS: A error: %w", errA)
	case errAAAA != nil:
		return nil, fmt.Errorf("failed to resolve DNS: AAAA error: %w", errAAAA)
	}

	return nil, fmt.Errorf("no suitable IP address found for %s", address)
}
// lookupIP sends a single DNS question of the given recordType for address to
// the configured server at index DNSServer and returns the first answer of the
// matching type (A or AAAA).
//
// The response is written out through d.client.WriteRecord under a
// "dns:<address>?TYPE=<A|AAAA>" URI before the answers are inspected. An error
// is returned when the exchange fails or when the response holds no record of
// the requested type.
func (d *customDialer) lookupIP(ctx context.Context, address string, recordType uint16, DNSServer int) (net.IP, error) {
	// Label used both in the record URI and in the "not found" error.
	label := "TYPE=A"
	if recordType == dns.TypeAAAA {
		label = "TYPE=AAAA"
	}

	query := new(dns.Msg)
	query.SetQuestion(dns.Fqdn(address), recordType)

	server := net.JoinHostPort(d.DNSConfig.Servers[DNSServer], d.DNSConfig.Port)
	resp, _, err := d.DNSClient.ExchangeContext(ctx, query, server)
	if err != nil {
		return nil, err
	}

	// Record the DNS response
	d.client.WriteRecord("dns:"+address+"?"+label, "resource", "text/dns", resp.String())

	// Return the first answer whose concrete type matches the requested one.
	for _, rr := range resp.Answer {
		switch rec := rr.(type) {
		case *dns.A:
			if recordType == dns.TypeA {
				return rec.A, nil
			}
		case *dns.AAAA:
			if recordType == dns.TypeAAAA {
				return rec.AAAA, nil
			}
		}
	}

	return nil, fmt.Errorf("no %s record found", label)
}