diff --git a/pkg/tools/dnsutils/dnsconfigs/handler.go b/pkg/tools/dnsutils/dnsconfigs/handler.go index 05c05d3fe8..dd13fd8823 100644 --- a/pkg/tools/dnsutils/dnsconfigs/handler.go +++ b/pkg/tools/dnsutils/dnsconfigs/handler.go @@ -30,22 +30,21 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/dnsutils" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/next" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/searches" + "github.com/networkservicemesh/sdk/pkg/tools/log" ) type dnsConfigsHandler struct { configs *dnsconfig.Map } -func (h *dnsConfigsHandler) ServeDNS(ctx context.Context, rp dns.ResponseWriter, m *dns.Msg) { +func (h *dnsConfigsHandler) ServeDNS(ctx context.Context, rw dns.ResponseWriter, m *dns.Msg) { dnsIPs := make([]url.URL, 0) searchDomains := make([]string, 0) h.configs.Range(func(key string, value []*networkservice.DNSConfig) bool { for _, conf := range value { for _, ip := range conf.DnsServerIps { - dnsIPs = append(dnsIPs, - url.URL{Scheme: "udp", Host: ip}, - url.URL{Scheme: "tcp", Host: ip}) + dnsIPs = append(dnsIPs, url.URL{Scheme: "udp", Host: ip}) } searchDomains = append(searchDomains, conf.SearchDomains...) } @@ -55,7 +54,36 @@ func (h *dnsConfigsHandler) ServeDNS(ctx context.Context, rp dns.ResponseWriter, ctx = clienturlctx.WithClientURLs(ctx, dnsIPs) ctx = searches.WithSearchDomains(ctx, searchDomains) - next.Handler(ctx).ServeDNS(ctx, rp, m) + + udpRW := &responseWriter{Response: nil} + next.Handler(ctx).ServeDNS(ctx, udpRW, m) + + if resp := udpRW.Response; resp != nil { + if err := rw.WriteMsg(resp); err != nil { + log.FromContext(ctx).WithField("dnsConfigHandler", "ServeDNS").Warnf("got an error during writing the message: %v", err.Error()) + dns.HandleFailed(rw, resp) + return + } + return + } + + for i := range dnsIPs { + dnsIPs[i].Scheme = "tcp" + } + + tcpRW := &responseWriter{Response: nil} + next.Handler(ctx).ServeDNS(ctx, tcpRW, m) + + if resp := tcpRW.Response; resp != nil { + if err := rw.WriteMsg(resp); err != nil { + log.FromContext(ctx).WithField("dnsConfigHandler", "ServeDNS").Warnf("got an error during writing the message: %v", err.Error()) + dns.HandleFailed(rw, resp) + return + } + return + } + + dns.HandleFailed(rw, m) } // NewDNSHandler creates a new dns handler that stores DNS configs diff --git a/pkg/tools/dnsutils/dnsconfigs/response_writer.go b/pkg/tools/dnsutils/dnsconfigs/response_writer.go new file mode 100644 index 0000000000..7e396fc389 --- /dev/null +++ b/pkg/tools/dnsutils/dnsconfigs/response_writer.go @@ -0,0 +1,31 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dnsconfigs + +import ( + "github.com/miekg/dns" +) + +type responseWriter struct { + dns.ResponseWriter + Response *dns.Msg +} + +func (r *responseWriter) WriteMsg(m *dns.Msg) error { + r.Response = m + return nil +} diff --git a/pkg/tools/dnsutils/searches/handler.go b/pkg/tools/dnsutils/searches/handler.go index 2ae7dcdff6..3b4dae28d3 100644 --- a/pkg/tools/dnsutils/searches/handler.go +++ b/pkg/tools/dnsutils/searches/handler.go @@ -19,7 +19,6 @@ package searches import ( "context" - "time" "github.com/miekg/dns" @@ -28,10 +27,6 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/log" ) -const ( - timeout = 5 * time.Second -) - type searchDomainsHandler struct { } @@ -41,12 +36,9 @@ func (h *searchDomainsHandler) ServeDNS(ctx context.Context, rw dns.ResponseWrit r := &responseWriter{ ResponseWriter: rw, Responses: make([]*dns.Msg, len(domains)+1), - index: 0, + Index: 0, } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - next.Handler(ctx).ServeDNS(ctx, r, m) for _, d := range SearchDomains(ctx) { @@ -71,6 +63,6 @@ func (h *searchDomainsHandler) ServeDNS(ctx context.Context, rw dns.ResponseWrit } // NewDNSHandler creates a new dns handler that makes requests to all subdomains received from dns configs -func NewDNSHandler() dnsutils.Handler { - return new(searchDomainsHandler) +func NewDNSHandler(domains []string) dnsutils.Handler { + return &searchDomainsHandler{SearchDomains: domains} } diff --git a/pkg/tools/dnsutils/searches/response_writer.go b/pkg/tools/dnsutils/searches/response_writer.go index acbe6c787d..253ad7649b 100644 --- a/pkg/tools/dnsutils/searches/response_writer.go +++ b/pkg/tools/dnsutils/searches/response_writer.go @@ -23,11 +23,11 @@ import ( type responseWriter struct { dns.ResponseWriter Responses []*dns.Msg - index int + Index int } func (r *responseWriter) WriteMsg(m *dns.Msg) error { - r.Responses[r.index] = m - r.index++ + r.Responses[r.Index] = m + r.Index++ return nil }