From 774304c70923224e2f011ecc026428c929837271 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Thu, 7 Dec 2023 22:14:43 +0700 Subject: [PATCH] Fix authorizeClient on failed refresh (#1558) * Fix authorizeClient on failed refresh Signed-off-by: Artem Glazychev * Fix registry Signed-off-by: Artem Glazychev --------- Signed-off-by: Artem Glazychev --- pkg/networkservice/common/authorize/client.go | 15 +- .../common/authorize/client_test.go | 136 ++++++++++++++++++ .../common/authorize/metadata.go | 44 ++++++ pkg/registry/common/authorize/ns_client.go | 10 +- .../common/authorize/ns_client_test.go | 18 ++- pkg/registry/common/authorize/nse_client.go | 10 +- .../common/authorize/nse_client_test.go | 18 ++- pkg/tools/opa/opainput.go | 5 - pkg/tools/spiffejwt/token.go | 3 - 9 files changed, 232 insertions(+), 27 deletions(-) create mode 100644 pkg/networkservice/common/authorize/client_test.go create mode 100644 pkg/networkservice/common/authorize/metadata.go diff --git a/pkg/networkservice/common/authorize/client.go b/pkg/networkservice/common/authorize/client.go index 54d348af9..f770feefe 100644 --- a/pkg/networkservice/common/authorize/client.go +++ b/pkg/networkservice/common/authorize/client.go @@ -30,6 +30,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" "github.com/networkservicemesh/sdk/pkg/tools/opa" "github.com/networkservicemesh/sdk/pkg/tools/postpone" ) @@ -84,16 +85,18 @@ func (a *authorizeClient) Request(ctx context.Context, request *networkservice.N } if err = a.policies.check(ctx, conn.GetPath()); err != nil { - closeCtx, cancelClose := postponeCtxFunc() - defer cancelClose() + if !load(ctx, metadata.IsClient(a)) { + closeCtx, cancelClose := postponeCtxFunc() + defer cancelClose() - if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil { - err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error()) + if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil { + err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error()) + } } - return nil, err } + store(ctx, metadata.IsClient(a)) return conn, nil } @@ -102,6 +105,8 @@ func (a *authorizeClient) Close(ctx context.Context, conn *networkservice.Connec if ok && p != nil { ctx = peer.NewContext(ctx, p) } + del(ctx, metadata.IsClient(a)) + if err := a.policies.check(ctx, conn.GetPath()); err != nil { return nil, err } diff --git a/pkg/networkservice/common/authorize/client_test.go b/pkg/networkservice/common/authorize/client_test.go new file mode 100644 index 000000000..d991f5f95 --- /dev/null +++ b/pkg/networkservice/common/authorize/client_test.go @@ -0,0 +1,136 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorize_test + +import ( + "context" + "os" + "path" + "path/filepath" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +func TestAuthClient(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + dir := filepath.Clean(path.Join(os.TempDir(), t.Name())) + defer func() { + _ = os.RemoveAll(dir) + }() + + err := os.MkdirAll(dir, os.ModePerm) + require.Nil(t, err) + + policyPath := filepath.Clean(path.Join(dir, "policy.rego")) + err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm) + require.Nil(t, err) + + suits := []struct { + name string + policyPath string + request *networkservice.NetworkServiceRequest + response *networkservice.Connection + denied bool + }{ + { + name: "simple positive test", + policyPath: policyPath, + request: requestWithToken("allowed"), + denied: false, + }, + { + name: "simple negative test", + policyPath: policyPath, + request: requestWithToken("not_allowed"), + denied: true, + }, + } + + for i := range suits { + s := suits[i] + t.Run(s.name, func(t *testing.T) { + client := chain.NewNetworkServiceClient( + metadata.NewClient(), + authorize.NewClient(authorize.WithPolicies(s.policyPath)), + ) + checkResult := func(err error) { + if !s.denied { + require.Nil(t, err, "request expected to be not denied: ") + return + } + require.NotNil(t, err, "request expected to be denied") + s, ok := status.FromError(errors.Cause(err)) + require.True(t, ok, "error without error status code"+err.Error()) + require.Equal(t, s.Code(), codes.PermissionDenied, "wrong error status code") + } + + _, err := client.Request(context.Background(), s.request) + checkResult(err) + + _, err = client.Close(context.Background(), s.request.GetConnection()) + checkResult(err) + }) + } +} + +func TestAuthClientFailedRefresh(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + dir := filepath.Clean(path.Join(os.TempDir(), t.Name())) + defer func() { + _ = os.RemoveAll(dir) + }() + + err := os.MkdirAll(dir, os.ModePerm) + require.Nil(t, err) + + policyPath := filepath.Clean(path.Join(dir, "policy.rego")) + err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm) + require.Nil(t, err) + + counter := new(count.Client) + client := chain.NewNetworkServiceClient( + metadata.NewClient(), + authorize.NewClient(authorize.WithPolicies(policyPath)), + counter, + ) + + conn, err := client.Request(context.Background(), requestWithToken("allowed")) + require.Nil(t, err) + + refreshRequest := requestWithToken("not_allowed") + _, err = client.Request(context.Background(), refreshRequest) + require.NotNil(t, err) + require.Equal(t, 0, counter.Closes()) + + _, err = client.Close(context.Background(), conn) + require.Nil(t, err) + require.Equal(t, 1, counter.Closes()) +} diff --git a/pkg/networkservice/common/authorize/metadata.go b/pkg/networkservice/common/authorize/metadata.go new file mode 100644 index 000000000..ca10b5c5c --- /dev/null +++ b/pkg/networkservice/common/authorize/metadata.go @@ -0,0 +1,44 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorize + +import ( + "context" + + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +type key struct{} + +// store sets a flag stored per Connection.Id metadata. +// It is used to keep a successful Request. +// Based on this, we can understand whether the Request is a refresh. +func store(ctx context.Context, isClient bool) { + metadata.Map(ctx, isClient).Store(key{}, struct{}{}) +} + +// load returns a flag stored per Connection.Id metadata. +// It is used to determine a refresh +func load(ctx context.Context, isClient bool) (ok bool) { + _, ok = metadata.Map(ctx, isClient).Load(key{}) + return +} + +// del deletes a flag stored per Connection.Id metadata. +func del(ctx context.Context, isClient bool) { + metadata.Map(ctx, isClient).Delete(key{}) +} diff --git a/pkg/registry/common/authorize/ns_client.go b/pkg/registry/common/authorize/ns_client.go index e6b8e7876..7a24454be 100644 --- a/pkg/registry/common/authorize/ns_client.go +++ b/pkg/registry/common/authorize/ns_client.go @@ -89,11 +89,13 @@ func (c *authorizeNSClient) Register(ctx context.Context, ns *registry.NetworkSe Index: path.Index, } if err := c.policies.check(ctx, input); err != nil { - unregisterCtx, cancelUnregister := postponeCtxFunc() - defer cancelUnregister() + if _, load := c.nsPathIdsMap.Load(resp.Name); !load { + unregisterCtx, cancelUnregister := postponeCtxFunc() + defer cancelUnregister() - if _, unregisterErr := next.NetworkServiceRegistryClient(ctx).Unregister(unregisterCtx, resp, opts...); unregisterErr != nil { - err = errors.Wrapf(err, "nse unregistered with error: %s", unregisterErr.Error()) + if _, unregisterErr := next.NetworkServiceRegistryClient(ctx).Unregister(unregisterCtx, resp, opts...); unregisterErr != nil { + err = errors.Wrapf(err, "nse unregistered with error: %s", unregisterErr.Error()) + } } return nil, err diff --git a/pkg/registry/common/authorize/ns_client_test.go b/pkg/registry/common/authorize/ns_client_test.go index cf676a6f5..73d0da7d6 100644 --- a/pkg/registry/common/authorize/ns_client_test.go +++ b/pkg/registry/common/authorize/ns_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2023 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -25,6 +25,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/utils/count" "go.uber.org/goleak" ) @@ -32,8 +34,11 @@ import ( func TestNSRegistryAuthorizeClient(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - client := authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")) - require.NotNil(t, client) + var callCounter = &count.CallCounter{} + client := chain.NewNetworkServiceRegistryClient( + authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")), + count.NewNetworkServiceRegistryClient(callCounter), + ) ns := ®istry.NetworkService{Name: "ns"} path1 := getPath(t, spiffeid1) @@ -45,20 +50,27 @@ func TestNSRegistryAuthorizeClient(t *testing.T) { ns.PathIds = []string{spiffeid1} _, err := client.Register(ctx1, ns) require.NoError(t, err) + require.Equal(t, callCounter.Registers(), 1) ns.PathIds = []string{spiffeid2} _, err = client.Register(ctx2, ns) require.Error(t, err) + require.Equal(t, callCounter.Registers(), 2) + require.Equal(t, callCounter.Unregisters(), 0) ns.PathIds = []string{spiffeid1} _, err = client.Register(ctx1, ns) require.NoError(t, err) + require.Equal(t, callCounter.Registers(), 3) + require.Equal(t, callCounter.Unregisters(), 0) ns.PathIds = []string{spiffeid2} _, err = client.Unregister(ctx2, ns) require.Error(t, err) + require.Equal(t, callCounter.Unregisters(), 1) ns.PathIds = []string{spiffeid1} _, err = client.Unregister(ctx1, ns) require.NoError(t, err) + require.Equal(t, callCounter.Unregisters(), 2) } diff --git a/pkg/registry/common/authorize/nse_client.go b/pkg/registry/common/authorize/nse_client.go index e3a1822d9..70b324462 100644 --- a/pkg/registry/common/authorize/nse_client.go +++ b/pkg/registry/common/authorize/nse_client.go @@ -88,11 +88,13 @@ func (c *authorizeNSEClient) Register(ctx context.Context, nse *registry.Network Index: path.Index, } if err := c.policies.check(ctx, input); err != nil { - unregisterCtx, cancelUnregister := postponeCtxFunc() - defer cancelUnregister() + if _, load := c.nsePathIdsMap.Load(resp.Name); !load { + unregisterCtx, cancelUnregister := postponeCtxFunc() + defer cancelUnregister() - if _, unregisterErr := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(unregisterCtx, resp, opts...); unregisterErr != nil { - err = errors.Wrapf(err, "nse unregistered with error: %s", unregisterErr.Error()) + if _, unregisterErr := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(unregisterCtx, resp, opts...); unregisterErr != nil { + err = errors.Wrapf(err, "nse unregistered with error: %s", unregisterErr.Error()) + } } return nil, err diff --git a/pkg/registry/common/authorize/nse_client_test.go b/pkg/registry/common/authorize/nse_client_test.go index 54b89f87c..595ae9692 100644 --- a/pkg/registry/common/authorize/nse_client_test.go +++ b/pkg/registry/common/authorize/nse_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2023 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -25,6 +25,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/utils/count" "go.uber.org/goleak" ) @@ -32,8 +34,11 @@ import ( func TestNSERegistryAuthorizeClient(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - client := authorize.NewNetworkServiceEndpointRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")) - require.NotNil(t, client) + var callCounter = &count.CallCounter{} + client := chain.NewNetworkServiceEndpointRegistryClient( + authorize.NewNetworkServiceEndpointRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")), + count.NewNetworkServiceEndpointRegistryClient(callCounter), + ) nse := ®istry.NetworkServiceEndpoint{Name: "nse"} path1 := getPath(t, spiffeid1) @@ -45,20 +50,27 @@ func TestNSERegistryAuthorizeClient(t *testing.T) { nse.PathIds = []string{spiffeid1} _, err := client.Register(ctx1, nse) require.NoError(t, err) + require.Equal(t, callCounter.Registers(), 1) nse.PathIds = []string{spiffeid2} _, err = client.Register(ctx2, nse) require.Error(t, err) + require.Equal(t, callCounter.Registers(), 2) + require.Equal(t, callCounter.Unregisters(), 0) nse.PathIds = []string{spiffeid1} _, err = client.Register(ctx1, nse) require.NoError(t, err) + require.Equal(t, callCounter.Registers(), 3) + require.Equal(t, callCounter.Unregisters(), 0) nse.PathIds = []string{spiffeid2} _, err = client.Unregister(ctx2, nse) require.Error(t, err) + require.Equal(t, callCounter.Unregisters(), 1) nse.PathIds = []string{spiffeid1} _, err = client.Unregister(ctx1, nse) require.NoError(t, err) + require.Equal(t, callCounter.Unregisters(), 2) } diff --git a/pkg/tools/opa/opainput.go b/pkg/tools/opa/opainput.go index 9af6c2209..0a6a9ad7e 100644 --- a/pkg/tools/opa/opainput.go +++ b/pkg/tools/opa/opainput.go @@ -71,11 +71,6 @@ func ParseX509Cert(authInfo credentials.AuthInfo) *x509.Certificate { } } - if tlsInfo, ok := authInfo.(*credentials.TLSInfo); ok { - if len(tlsInfo.State.PeerCertificates) > 0 { - peerCert = tlsInfo.State.PeerCertificates[0] - } - } return peerCert } diff --git a/pkg/tools/spiffejwt/token.go b/pkg/tools/spiffejwt/token.go index 9cc3d8981..d9b09a1f8 100644 --- a/pkg/tools/spiffejwt/token.go +++ b/pkg/tools/spiffejwt/token.go @@ -41,9 +41,6 @@ func TokenGeneratorFunc(source x509svid.Source, maxTokenLifeTime time.Duration) if ownSVID.Certificates[0].NotAfter.Before(expireTime) { expireTime = ownSVID.Certificates[0].NotAfter } - if err != nil { - return "", time.Time{}, errors.Wrap(err, "Error creating Token") - } claims := jwt.RegisteredClaims{ Subject: ownSVID.ID.String(), ExpiresAt: jwt.NewNumericDate(expireTime),