diff --git a/pkg/networkservice/core/next/client.go b/pkg/networkservice/core/next/client.go index 0f5f5dc05..d07083c44 100644 --- a/pkg/networkservice/core/next/client.go +++ b/pkg/networkservice/core/next/client.go @@ -28,8 +28,9 @@ import ( ) type nextClient struct { - clients []networkservice.NetworkServiceClient - index int + clients []networkservice.NetworkServiceClient + index int + nextParent networkservice.NetworkServiceClient } // ClientWrapper - a function that wraps around a networkservice.NetworkServiceClient @@ -40,11 +41,9 @@ type ClientChainer func(...networkservice.NetworkServiceClient) networkservice.N // NewWrappedNetworkServiceClient chains together clients with wrapper wrapped around each one func NewWrappedNetworkServiceClient(wrapper ClientWrapper, clients ...networkservice.NetworkServiceClient) networkservice.NetworkServiceClient { - rv := &nextClient{ - clients: clients, - } - for i := range rv.clients { - rv.clients[i] = wrapper(rv.clients[i]) + rv := &nextClient{clients: make([]networkservice.NetworkServiceClient, 0, len(clients))} + for _, c := range clients { + rv.clients = append(rv.clients, wrapper(c)) } return rv } @@ -54,23 +53,31 @@ func NewNetworkServiceClient(clients ...networkservice.NetworkServiceClient) net if len(clients) == 0 { return &tailClient{} } - return NewWrappedNetworkServiceClient(notWrapClient, clients...) + return NewWrappedNetworkServiceClient(func(client networkservice.NetworkServiceClient) networkservice.NetworkServiceClient { + return client + }, clients...) } func (n *nextClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { + if n.index == 0 && ctx != nil { + if nextParent := Client(ctx); nextParent != nil { + n.nextParent = nextParent + } + } if n.index+1 < len(n.clients) { - return n.clients[n.index].Request(withNextClient(ctx, &nextClient{clients: n.clients, index: n.index + 1}), request, opts...) + return n.clients[n.index].Request(withNextClient(ctx, &nextClient{nextParent: n.nextParent, clients: n.clients, index: n.index + 1}), request, opts...) } - return n.clients[n.index].Request(withNextClient(ctx, nil), request, opts...) + return n.clients[n.index].Request(withNextClient(ctx, n.nextParent), request, opts...) } func (n *nextClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) { + if n.index == 0 && ctx != nil { + if nextParent := Client(ctx); nextParent != nil { + n.nextParent = nextParent + } + } if n.index+1 < len(n.clients) { - return n.clients[n.index].Close(withNextClient(ctx, &nextClient{clients: n.clients, index: n.index + 1}), conn, opts...) + return n.clients[n.index].Close(withNextClient(ctx, &nextClient{nextParent: n.nextParent, clients: n.clients, index: n.index + 1}), conn, opts...) } - return n.clients[n.index].Close(withNextClient(ctx, nil), conn, opts...) -} - -func notWrapClient(c networkservice.NetworkServiceClient) networkservice.NetworkServiceClient { - return c + return n.clients[n.index].Close(withNextClient(ctx, n.nextParent), conn, opts...) } diff --git a/pkg/networkservice/core/next/server.go b/pkg/networkservice/core/next/server.go index 8f2bb75db..990d42629 100644 --- a/pkg/networkservice/core/next/server.go +++ b/pkg/networkservice/core/next/server.go @@ -27,8 +27,9 @@ import ( ) type nextServer struct { - servers []networkservice.NetworkServiceServer - index int + nextParent networkservice.NetworkServiceServer + servers []networkservice.NetworkServiceServer + index int } // ServerWrapper - A function that wraps a networkservice.NetworkServiceServer @@ -39,11 +40,9 @@ type ServerChainer func(...networkservice.NetworkServiceServer) networkservice.N // NewWrappedNetworkServiceServer - chains together the servers provides with the wrapper wrapped around each one in turn. func NewWrappedNetworkServiceServer(wrapper ServerWrapper, servers ...networkservice.NetworkServiceServer) networkservice.NetworkServiceServer { - rv := &nextServer{ - servers: servers, - } - for i := range rv.servers { - rv.servers[i] = wrapper(rv.servers[i]) + rv := &nextServer{servers: make([]networkservice.NetworkServiceServer, 0, len(servers))} + for _, s := range servers { + rv.servers = append(rv.servers, wrapper(s)) } return rv } @@ -54,23 +53,31 @@ func NewNetworkServiceServer(servers ...networkservice.NetworkServiceServer) net if len(servers) == 0 { return &tailServer{} } - return NewWrappedNetworkServiceServer(notWrapServer, servers...) + return NewWrappedNetworkServiceServer(func(server networkservice.NetworkServiceServer) networkservice.NetworkServiceServer { + return server + }, servers...) } func (n *nextServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + if n.index == 0 && ctx != nil { + if nextParent := Server(ctx); nextParent != nil { + n.nextParent = nextParent + } + } if n.index+1 < len(n.servers) { - return n.servers[n.index].Request(withNextServer(ctx, &nextServer{servers: n.servers, index: n.index + 1}), request) + return n.servers[n.index].Request(withNextServer(ctx, &nextServer{nextParent: n.nextParent, servers: n.servers, index: n.index + 1}), request) } - return n.servers[n.index].Request(withNextServer(ctx, nil), request) + return n.servers[n.index].Request(withNextServer(ctx, n.nextParent), request) } func (n *nextServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + if n.index == 0 && ctx != nil { + if nextParent := Server(ctx); nextParent != nil { + n.nextParent = nextParent + } + } if n.index+1 < len(n.servers) { - return n.servers[n.index].Close(withNextServer(ctx, &nextServer{servers: n.servers, index: n.index + 1}), conn) + return n.servers[n.index].Close(withNextServer(ctx, &nextServer{nextParent: n.nextParent, servers: n.servers, index: n.index + 1}), conn) } - return n.servers[n.index].Close(withNextServer(ctx, nil), conn) -} - -func notWrapServer(server networkservice.NetworkServiceServer) networkservice.NetworkServiceServer { - return server + return n.servers[n.index].Close(withNextServer(ctx, n.nextParent), conn) } diff --git a/pkg/networkservice/core/next/tests/client_test.go b/pkg/networkservice/core/next/tests/client_test.go index 8c36b0722..36a10fd78 100644 --- a/pkg/networkservice/core/next/tests/client_test.go +++ b/pkg/networkservice/core/next/tests/client_test.go @@ -48,8 +48,10 @@ func TestClientBranches(t *testing.T) { {visitClient(), visitClient(), adapters.NewServerToClient(visitServer())}, {visitClient(), adapters.NewServerToClient(emptyServer()), visitClient()}, {visitClient(), adapters.NewServerToClient(visitServer()), emptyClient()}, + {visitClient(), next.NewNetworkServiceClient(next.NewNetworkServiceClient(visitClient())), visitClient()}, + {visitClient(), next.NewNetworkServiceClient(next.NewNetworkServiceClient(emptyClient())), visitClient()}, } - expects := []int{1, 2, 3, 0, 1, 2, 1, 2, 3, 3, 1, 2} + expects := []int{1, 2, 3, 0, 1, 2, 1, 2, 3, 3, 1, 2, 3, 1} for i, sample := range samples { msg := fmt.Sprintf("sample index: %v", i) ctx := visit(context.Background()) diff --git a/pkg/networkservice/core/next/tests/server_test.go b/pkg/networkservice/core/next/tests/server_test.go index ceca795b4..c6fd113de 100644 --- a/pkg/networkservice/core/next/tests/server_test.go +++ b/pkg/networkservice/core/next/tests/server_test.go @@ -50,8 +50,10 @@ func TestServerBranches(t *testing.T) { {visitServer(), visitServer(), adapters.NewClientToServer(visitClient())}, {visitServer(), adapters.NewClientToServer(emptyClient()), visitServer()}, {visitServer(), adapters.NewClientToServer(visitClient()), emptyServer()}, + {visitServer(), next.NewNetworkServiceServer(next.NewNetworkServiceServer(visitServer())), visitServer()}, + {visitServer(), next.NewNetworkServiceServer(next.NewNetworkServiceServer(emptyServer())), visitServer()}, } - expects := []int{1, 2, 3, 0, 1, 2, 1, 2, 3, 3, 1, 2} + expects := []int{1, 2, 3, 0, 1, 2, 1, 2, 3, 3, 1, 2, 3, 1} for i, sample := range servers { ctx := visit(context.Background()) s := next.NewNetworkServiceServer(sample...)