diff --git a/pkg/networkservice/chains/nsmgr/vl3_test.go b/pkg/networkservice/chains/nsmgr/vl3_test.go index 000900baf..7583e16f7 100644 --- a/pkg/networkservice/chains/nsmgr/vl3_test.go +++ b/pkg/networkservice/chains/nsmgr/vl3_test.go @@ -23,7 +23,6 @@ import ( "context" "fmt" "net" - "net/url" "testing" "time" @@ -40,6 +39,12 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/sandbox" ) +func staticIP(addr net.IP) func() net.IP { + return func() net.IP { + return addr + } +} + func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -70,7 +75,7 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { sandbox.GenerateTestToken, vl3.NewServer(ctx, serverPrefixCh), vl3dns.NewServer(ctx, - &url.URL{Scheme: "tcp", Host: "127.0.0.1"}, + staticIP(net.ParseIP("127.0.0.1")), vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), vl3dns.WithDNSPort(40053)), ) @@ -161,7 +166,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { sandbox.GenerateTestToken, vl3.NewServer(ctx, serverPrefixCh), vl3dns.NewServer(ctx, - &url.URL{Scheme: "tcp", Host: "127.0.0.1"}, + staticIP(net.ParseIP("127.0.0.1")), vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), vl3dns.WithDNSListenAndServeFunc(func(ctx context.Context, handler dnsutils.Handler, listenOn string) { dnsutils.ListenAndServe(ctx, handler, ":50053") @@ -182,7 +187,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { defer close(clientPrefixCh) clientPrefixCh <- &ipam.PrefixResponse{Prefix: "127.0.0.1/32"} - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3.NewClient(ctx, clientPrefixCh), vl3dns.NewClient(&url.URL{Host: "127.0.0.1"}))) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3.NewClient(ctx, clientPrefixCh), vl3dns.NewClient(net.ParseIP("127.0.0.1")))) req := defaultRequest(nsReg.Name) req.Connection.Id = uuid.New().String() diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go index 1864c6cc5..ed016fad8 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go @@ -18,7 +18,7 @@ package vl3dns import ( "context" - "net/url" + "net" "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/networkservice" @@ -28,14 +28,14 @@ import ( ) type vl3DNSClient struct { - listenOn *url.URL + dnsServerIP net.IP } // NewClient - returns a new null client that does nothing but call next.Client(ctx).{Request/Close} and return the result // This is very useful in testing -func NewClient(listenOn *url.URL) networkservice.NetworkServiceClient { +func NewClient(dnsServerIP net.IP) networkservice.NetworkServiceClient { return &vl3DNSClient{ - listenOn: listenOn, + dnsServerIP: dnsServerIP, } } @@ -52,7 +52,7 @@ func (n *vl3DNSClient) Request(ctx context.Context, request *networkservice.Netw request.GetConnection().GetContext().GetDnsContext().Configs = []*networkservice.DNSConfig{ { - DnsServerIps: []string{n.listenOn.Hostname()}, + DnsServerIps: []string{n.dnsServerIP.String()}, }, } return next.Client(ctx).Request(ctx, request, opts...) diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go index e6cd50ad7..d89a901e7 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go @@ -37,6 +37,7 @@ import ( dnsnext "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/next" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion" + "github.com/networkservicemesh/sdk/pkg/tools/ippool" ) type vl3DNSServer struct { @@ -47,7 +48,7 @@ type vl3DNSServer struct { dnsPort int dnsServer dnsutils.Handler listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string) - listenOn *url.URL + getDNSServerIP func() net.IP } type clientDNSNameKey struct{} @@ -57,11 +58,11 @@ type clientDNSNameKey struct{} // By default is using fanout dns handler to connect to other vl3 nses. // chanCtx is using for signal to stop dns server. // opts confugre vl3dns networkservice instance with specific behavior. -func NewServer(chanCtx context.Context, listenOn *url.URL, opts ...Option) networkservice.NetworkServiceServer { +func NewServer(chanCtx context.Context, getDNSServerIP func() net.IP, opts ...Option) networkservice.NetworkServiceServer { var result = &vl3DNSServer{ dnsPort: 53, listenAndServeDNS: dnsutils.ListenAndServe, - listenOn: listenOn, + getDNSServerIP: getDNSServerIP, } for _, opt := range opts { @@ -89,15 +90,20 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw var dnsContext = request.GetConnection().GetContext().GetDnsContext() - for _, config := range dnsContext.GetConfigs() { - for _, serverIP := range config.DnsServerIps { - var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", serverIP, n.dnsPort)} - n.fanoutAddresses.Store(u, struct{}{}) + if srcRoutes := request.GetConnection().GetContext().GetIpContext().GetSrcIPRoutes(); len(srcRoutes) > 0 { + var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix + for _, config := range dnsContext.GetConfigs() { + for _, serverIP := range config.DnsServerIps { + if withinPrefix(serverIP, lastPrefix) { + var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", serverIP, n.dnsPort)} + n.fanoutAddresses.Store(u, struct{}{}) + } + } } } dnsContext.Configs = append(dnsContext.Configs, &networkservice.DNSConfig{ - DnsServerIps: []string{n.listenOn.Hostname()}, + DnsServerIps: []string{n.getDNSServerIP().String()}, }) var recordNames, err = n.buildSrcDNSRecords(request.GetConnection()) @@ -178,3 +184,12 @@ func compareStringSlices(a, b []string) bool { } return true } + +func withinPrefix(ipAddr, prefix string) bool { + _, ipNet, err := net.ParseCIDR(prefix) + if err != nil { + return false + } + var pool = ippool.NewWithNet(ipNet) + return pool.ContainsString(ipAddr) +}