Skip to content

Commit

Permalink
additional fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art committed Feb 6, 2023
1 parent c8e9bd0 commit 8e4f139
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 49 deletions.
16 changes: 9 additions & 7 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -227,7 +227,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

domain := sandbox.NewBuilder(ctx, t).
Expand Down Expand Up @@ -277,14 +277,16 @@ func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
}),
))

reqCtx, reqClose := context.WithTimeout(ctx, time.Second*10)
defer reqClose()

req := defaultRequest(nsReg.Name)
_, err = nsc.Request(reqCtx, req)
req.Connection.Labels["podName"] = nscName
_, err = nsc.Request(ctx, req)
require.NoError(t, err)

dnsServerIPCh <- net.ParseIP("127.0.0.1")

<-refreshCompletedCh
select {
case <-ctx.Done():
case <-refreshCompletedCh:
}
require.NoError(t, ctx.Err())
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
//
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down
81 changes: 40 additions & 41 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"text/template"

"github.com/edwarnicke/serialize"
"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"

Expand All @@ -45,16 +45,16 @@ import (
)

type vl3DNSServer struct {
chanCtx context.Context
chainCtx context.Context
dnsServerRecords memory.Map
dnsConfigs *dnsconfig.Map
domainSchemeTemplates []*template.Template
dnsPort int
dnsServer dnsutils.Handler
listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string)
dnsServerIP net.IP
serverAddressCh <-chan net.IP
executor serialize.Executor
dnsServerIP atomic.Value
dnsServerIPCh <-chan net.IP
monitorEventConsumer monitor.EventConsumer
once sync.Once
}

Expand All @@ -63,15 +63,15 @@ type clientDNSNameKey struct{}
// NewServer creates a new vl3dns netwrokservice server.
// It starts dns server on the passed port/url. By default listens ":53".
// By default is using fanout dns handler to connect to other vl3 nses.
// chanCtx is using for signal to stop dns server.
// opts confugre vl3dns networkservice instance with specific behavior.
func NewServer(chanCtx context.Context, serverAddressCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer {
// chainCtx is using for signal to stop dns server.
// opts configure vl3dns networkservice instance with specific behavior.
func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer {
var result = &vl3DNSServer{
chanCtx: chanCtx,
chainCtx: chainCtx,
dnsPort: 53,
listenAndServeDNS: dnsutils.ListenAndServe,
dnsConfigs: new(dnsconfig.Map),
serverAddressCh: serverAddressCh,
dnsServerIPCh: dnsServerIPCh,
}

for _, opt := range opts {
Expand All @@ -88,16 +88,18 @@ func NewServer(chanCtx context.Context, serverAddressCh <-chan net.IP, opts ...O
)
}

result.listenAndServeDNS(chanCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))
result.listenAndServeDNS(chainCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))

if len(serverAddressCh) > 0 {
result.dnsServerIP = <-serverAddressCh
if len(dnsServerIPCh) > 0 {
result.dnsServerIP.Store(<-dnsServerIPCh)
}
return result
}

func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
n.once.Do(func() {
// We assume here that the monitorEventConsumer is the same for all connections.
// We need the context of any request to pull it out.
go n.checkServerAddressUpdates(ctx)
})

Expand Down Expand Up @@ -125,9 +127,11 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw

ips := getSrcIPs(request.GetConnection())
if len(ips) > 0 {
<-n.executor.AsyncExec(func() {
n.storeSrcIPs(ctx, recordNames, ips)
})
for _, recordName := range recordNames {
n.dnsServerRecords.Store(recordName, ips)
}

metadata.Map(ctx, false).Store(clientDNSNameKey{}, recordNames)
}

resp, err := next.Server(ctx).Request(ctx, request)
Expand All @@ -149,7 +153,6 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw
}
n.dnsConfigs.Store(resp.GetId(), configs)
}

return resp, err
}

Expand All @@ -167,7 +170,8 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio
}

func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, ok bool) {
if dnsServerIP := n.dnsServerIP; dnsServerIP != nil {
if ip := n.dnsServerIP.Load(); ip != nil {
dnsServerIP := ip.(net.IP)
var dnsContext = c.GetContext().GetDnsContext()
configToAdd := &networkservice.DNSConfig{
DnsServerIps: []string{dnsServerIP.String()},
Expand All @@ -193,42 +197,37 @@ func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]strin
}

func (n *vl3DNSServer) checkServerAddressUpdates(ctx context.Context) {
n.monitorEventConsumer, _ = monitor.LoadEventConsumer(ctx, metadata.IsClient(n))
for {
select {
case <-n.chanCtx.Done():
case <-n.chainCtx.Done():
return
case addr, ok := <-n.serverAddressCh:
case addr, ok := <-n.dnsServerIPCh:
if !ok {
return
}

n.updateServerAddress(ctx, addr)
n.updateServerAddress(addr)
}
}
}

func (n *vl3DNSServer) updateServerAddress(ctx context.Context, address net.IP) {
<-n.executor.AsyncExec(func() {
n.dnsServerIP = address
logger := log.FromContext(ctx).WithField("vl3DNSServer", "Request")

if eventConsumer, ok := monitor.LoadEventConsumer(ctx, metadata.IsClient(n)); ok {
_ = eventConsumer.Send(&networkservice.ConnectionEvent{
Type: networkservice.ConnectionEventType_UPDATE,
Connections: eventConsumer.GetConnections(),
})
} else {
logger.Debug("eventConsumer is not presented")
}
})
}
func (n *vl3DNSServer) updateServerAddress(address net.IP) {
n.dnsServerIP.Store(address)

func (n *vl3DNSServer) storeSrcIPs(ctx context.Context, recordNames []string, ips []net.IP) {
for _, recordName := range recordNames {
n.dnsServerRecords.Store(recordName, ips)
if n.monitorEventConsumer != nil {
conns := n.monitorEventConsumer.GetConnections()
for _, c := range conns {
c.State = networkservice.State_REFRESH_REQUESTED
}
_ = n.monitorEventConsumer.Send(&networkservice.ConnectionEvent{
Type: networkservice.ConnectionEventType_UPDATE,
Connections: conns,
})
} else {
log.FromContext(n.chainCtx).WithField("vl3DNSServer", "updateServerAddress").
Debug("eventConsumer is not presented")
}

metadata.Map(ctx, false).Store(clientDNSNameKey{}, recordNames)
}

func compareStringSlices(a, b []string) bool {
Expand Down

0 comments on commit 8e4f139

Please sign in to comment.