Skip to content

Commit

Permalink
Resolve conflicting changes in DNS #309 #341 (#346)
Browse files Browse the repository at this point in the history
Co-authored-by: yuhan6665 <1588741+yuhan6665@users.noreply.github.com>
  • Loading branch information
Jim Han and yuhan6665 authored Mar 7, 2021
1 parent f50eff5 commit d7cd71b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
14 changes: 7 additions & 7 deletions app/dns/dohdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,13 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dispatcherCtx := context.Background()
if inbound := session.InboundFromContext(ctx); inbound != nil {
dispatcherCtx = session.ContextWithInbound(dispatcherCtx, inbound)
}
if content := session.ContentFromContext(ctx); content != nil {
dispatcherCtx = session.ContextWithContent(dispatcherCtx, content)
}
dispatcherCtx = internet.ContextWithLookupDomain(dispatcherCtx, internet.LookupDomainFromContext(ctx))

dest, err := net.ParseDestination(network + ":" + addr)
if err != nil {
return nil, err
}

dispatcherCtx = session.ContextWithContent(dispatcherCtx, &session.Content{Protocol: "tls"})
dispatcherCtx = log.ContextWithAccessMessage(dispatcherCtx, &log.AccessMessage{
From: "DoH",
To: s.dohURL,
Expand All @@ -76,6 +70,12 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.
})

link, err := s.dispatcher.Dispatch(dispatcherCtx, dest)
select {
case <-ctx.Done():
return nil, ctx.Err()
default:

}
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions app/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Server struct {
sync.Mutex
hosts *StaticHosts
clientIP net.IP
clients []Client // clientIdx -> Client
clients []Client // clientIdx -> Client
ctx context.Context
ipIndexMap []*MultiGeoIPMatcher // clientIdx -> *MultiGeoIPMatcher
domainRules [][]string // clientIdx -> domainRuleIdx -> DomainRule
Expand Down Expand Up @@ -307,7 +307,7 @@ func (s *Server) queryIPTimeout(idx int, client Client, domain string, option dn
Tag: s.tag,
})
}
ctx = internet.ContextWithLookupDomain(ctx, Fqdn(domain))
ctx = internet.ContextWithLookupDomain(ctx, domain)
ips, err := client.QueryIP(ctx, domain, option)
cancel()

Expand Down
28 changes: 18 additions & 10 deletions transport/internet/system_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ var (

// InitSystemDialer: It's private method and you are NOT supposed to use this function.
func InitSystemDialer(dc dns.Client, om outbound.Manager) {
effectiveSystemDialer.init(dc, om)
effectiveSystemDialer.Init(dc, om)
}

type SystemDialer interface {
Dial(ctx context.Context, source net.Address, destination net.Destination, sockopt *SocketConfig) (net.Conn, error)
init(dc dns.Client, om outbound.Manager)
Init(dc dns.Client, om outbound.Manager)
}

type DefaultSystemDialer struct {
Expand Down Expand Up @@ -63,22 +63,30 @@ func (d *DefaultSystemDialer) lookupIP(domain string, strategy DomainStrategy, l
return nil, nil
}

var lookup = d.dns.LookupIP
var option = dns.IPOption{
IPv4Enable: true,
IPv6Enable: true,
FakeEnable: false,
}

switch {
case strategy == DomainStrategy_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()):
if lookupIPv4, ok := d.dns.(dns.IPv4Lookup); ok {
lookup = lookupIPv4.LookupIPv4
option = dns.IPOption{
IPv4Enable: true,
IPv6Enable: false,
FakeEnable: false,
}
case strategy == DomainStrategy_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()):
if lookupIPv4, ok := d.dns.(dns.IPv4Lookup); ok {
lookup = lookupIPv4.LookupIPv4
option = dns.IPOption{
IPv4Enable: false,
IPv6Enable: true,
FakeEnable: false,
}
case strategy == DomainStrategy_AS_IS:
return nil, nil
}

return lookup(domain)
return d.dns.LookupIP(domain, option)
}

func (d *DefaultSystemDialer) canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool {
Expand Down Expand Up @@ -184,7 +192,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
}

func (d *DefaultSystemDialer) init(dc dns.Client, om outbound.Manager) {
func (d *DefaultSystemDialer) Init(dc dns.Client, om outbound.Manager) {
d.dns = dc
d.obm = om
}
Expand Down Expand Up @@ -249,7 +257,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer {
}
}

func (v *SimpleSystemDialer) init(_ dns.Client, _ outbound.Manager) {}
func (v *SimpleSystemDialer) Init(_ dns.Client, _ outbound.Manager) {}

func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr())
Expand Down

0 comments on commit d7cd71b

Please sign in to comment.