From 54969236e119787d27a8f816ac862717e592e6a4 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Fri, 17 Nov 2023 20:20:44 +0700 Subject: [PATCH] Fix authorizeClient on failed refresh Signed-off-by: Artem Glazychev --- pkg/networkservice/common/authorize/client.go | 15 +- .../common/authorize/client_test.go | 136 ++++++++++++++++++ .../common/authorize/metadata.go | 44 ++++++ pkg/tools/opa/opainput.go | 5 - pkg/tools/spiffejwt/token.go | 3 - 5 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 pkg/networkservice/common/authorize/client_test.go create mode 100644 pkg/networkservice/common/authorize/metadata.go diff --git a/pkg/networkservice/common/authorize/client.go b/pkg/networkservice/common/authorize/client.go index 54d348af9..f770feefe 100644 --- a/pkg/networkservice/common/authorize/client.go +++ b/pkg/networkservice/common/authorize/client.go @@ -30,6 +30,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" "github.com/networkservicemesh/sdk/pkg/tools/opa" "github.com/networkservicemesh/sdk/pkg/tools/postpone" ) @@ -84,16 +85,18 @@ func (a *authorizeClient) Request(ctx context.Context, request *networkservice.N } if err = a.policies.check(ctx, conn.GetPath()); err != nil { - closeCtx, cancelClose := postponeCtxFunc() - defer cancelClose() + if !load(ctx, metadata.IsClient(a)) { + closeCtx, cancelClose := postponeCtxFunc() + defer cancelClose() - if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil { - err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error()) + if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil { + err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error()) + } } - return nil, err } + store(ctx, metadata.IsClient(a)) return conn, nil } @@ -102,6 +105,8 @@ func (a *authorizeClient) Close(ctx context.Context, conn *networkservice.Connec if ok && p != nil { ctx = peer.NewContext(ctx, p) } + del(ctx, metadata.IsClient(a)) + if err := a.policies.check(ctx, conn.GetPath()); err != nil { return nil, err } diff --git a/pkg/networkservice/common/authorize/client_test.go b/pkg/networkservice/common/authorize/client_test.go new file mode 100644 index 000000000..d991f5f95 --- /dev/null +++ b/pkg/networkservice/common/authorize/client_test.go @@ -0,0 +1,136 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorize_test + +import ( + "context" + "os" + "path" + "path/filepath" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +func TestAuthClient(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + dir := filepath.Clean(path.Join(os.TempDir(), t.Name())) + defer func() { + _ = os.RemoveAll(dir) + }() + + err := os.MkdirAll(dir, os.ModePerm) + require.Nil(t, err) + + policyPath := filepath.Clean(path.Join(dir, "policy.rego")) + err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm) + require.Nil(t, err) + + suits := []struct { + name string + policyPath string + request *networkservice.NetworkServiceRequest + response *networkservice.Connection + denied bool + }{ + { + name: "simple positive test", + policyPath: policyPath, + request: requestWithToken("allowed"), + denied: false, + }, + { + name: "simple negative test", + policyPath: policyPath, + request: requestWithToken("not_allowed"), + denied: true, + }, + } + + for i := range suits { + s := suits[i] + t.Run(s.name, func(t *testing.T) { + client := chain.NewNetworkServiceClient( + metadata.NewClient(), + authorize.NewClient(authorize.WithPolicies(s.policyPath)), + ) + checkResult := func(err error) { + if !s.denied { + require.Nil(t, err, "request expected to be not denied: ") + return + } + require.NotNil(t, err, "request expected to be denied") + s, ok := status.FromError(errors.Cause(err)) + require.True(t, ok, "error without error status code"+err.Error()) + require.Equal(t, s.Code(), codes.PermissionDenied, "wrong error status code") + } + + _, err := client.Request(context.Background(), s.request) + checkResult(err) + + _, err = client.Close(context.Background(), s.request.GetConnection()) + checkResult(err) + }) + } +} + +func TestAuthClientFailedRefresh(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + dir := filepath.Clean(path.Join(os.TempDir(), t.Name())) + defer func() { + _ = os.RemoveAll(dir) + }() + + err := os.MkdirAll(dir, os.ModePerm) + require.Nil(t, err) + + policyPath := filepath.Clean(path.Join(dir, "policy.rego")) + err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm) + require.Nil(t, err) + + counter := new(count.Client) + client := chain.NewNetworkServiceClient( + metadata.NewClient(), + authorize.NewClient(authorize.WithPolicies(policyPath)), + counter, + ) + + conn, err := client.Request(context.Background(), requestWithToken("allowed")) + require.Nil(t, err) + + refreshRequest := requestWithToken("not_allowed") + _, err = client.Request(context.Background(), refreshRequest) + require.NotNil(t, err) + require.Equal(t, 0, counter.Closes()) + + _, err = client.Close(context.Background(), conn) + require.Nil(t, err) + require.Equal(t, 1, counter.Closes()) +} diff --git a/pkg/networkservice/common/authorize/metadata.go b/pkg/networkservice/common/authorize/metadata.go new file mode 100644 index 000000000..c8f243c4f --- /dev/null +++ b/pkg/networkservice/common/authorize/metadata.go @@ -0,0 +1,44 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorize + +import ( + "context" + + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +type key struct{} + +// store sets a flag stored in per Connection.Id metadata. +// It is used to keep a successful Request. +// Based on this, we can understand whether the Request is a refresh. +func store(ctx context.Context, isClient bool) { + metadata.Map(ctx, isClient).Store(key{}, struct{}{}) +} + +// load returns a flag stored in per Connection.Id metadata. +// It is used to determine a refresh +func load(ctx context.Context, isClient bool) (ok bool) { + _, ok = metadata.Map(ctx, isClient).Load(key{}) + return +} + +// del deletes a flag stored in per Connection.Id metadata. +func del(ctx context.Context, isClient bool) { + metadata.Map(ctx, isClient).Delete(key{}) +} diff --git a/pkg/tools/opa/opainput.go b/pkg/tools/opa/opainput.go index 9af6c2209..0a6a9ad7e 100644 --- a/pkg/tools/opa/opainput.go +++ b/pkg/tools/opa/opainput.go @@ -71,11 +71,6 @@ func ParseX509Cert(authInfo credentials.AuthInfo) *x509.Certificate { } } - if tlsInfo, ok := authInfo.(*credentials.TLSInfo); ok { - if len(tlsInfo.State.PeerCertificates) > 0 { - peerCert = tlsInfo.State.PeerCertificates[0] - } - } return peerCert } diff --git a/pkg/tools/spiffejwt/token.go b/pkg/tools/spiffejwt/token.go index 9cc3d8981..d9b09a1f8 100644 --- a/pkg/tools/spiffejwt/token.go +++ b/pkg/tools/spiffejwt/token.go @@ -41,9 +41,6 @@ func TokenGeneratorFunc(source x509svid.Source, maxTokenLifeTime time.Duration) if ownSVID.Certificates[0].NotAfter.Before(expireTime) { expireTime = ownSVID.Certificates[0].NotAfter } - if err != nil { - return "", time.Time{}, errors.Wrap(err, "Error creating Token") - } claims := jwt.RegisteredClaims{ Subject: ownSVID.ID.String(), ExpiresAt: jwt.NewNumericDate(expireTime),