Skip to content

Commit

Permalink
Fix authorizeClient on failed refresh
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art committed Nov 17, 2023
1 parent 7e78948 commit e3a83f6
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 13 deletions.
15 changes: 10 additions & 5 deletions pkg/networkservice/common/authorize/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
"github.com/networkservicemesh/sdk/pkg/tools/opa"
"github.com/networkservicemesh/sdk/pkg/tools/postpone"
)
Expand Down Expand Up @@ -84,16 +85,18 @@ func (a *authorizeClient) Request(ctx context.Context, request *networkservice.N
}

if err = a.policies.check(ctx, conn.GetPath()); err != nil {
closeCtx, cancelClose := postponeCtxFunc()
defer cancelClose()
if !load(ctx, metadata.IsClient(a)) {
closeCtx, cancelClose := postponeCtxFunc()
defer cancelClose()

if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil {
err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error())
if _, closeErr := next.Client(ctx).Close(closeCtx, conn, opts...); closeErr != nil {
err = errors.Wrapf(err, "connection closed with error: %s", closeErr.Error())
}
}

return nil, err
}

store(ctx, metadata.IsClient(a))
return conn, nil
}

Expand All @@ -102,6 +105,8 @@ func (a *authorizeClient) Close(ctx context.Context, conn *networkservice.Connec
if ok && p != nil {
ctx = peer.NewContext(ctx, p)
}
del(ctx, metadata.IsClient(a))

if err := a.policies.check(ctx, conn.GetPath()); err != nil {
return nil, err
}
Expand Down
134 changes: 134 additions & 0 deletions pkg/networkservice/common/authorize/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) 2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package authorize_test

import (
"context"

Check failure on line 20 in pkg/networkservice/common/authorize/client_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint / golangci-lint

File is not `goimports`-ed with -local github.com/networkservicemesh/sdk (goimports)
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"os"
"path"
"path/filepath"
"testing"

"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/count"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)

func TestAuthClient(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

dir := filepath.Clean(path.Join(os.TempDir(), t.Name()))
defer func() {
_ = os.RemoveAll(dir)
}()

err := os.MkdirAll(dir, os.ModePerm)
require.Nil(t, err)

policyPath := filepath.Clean(path.Join(dir, "policy.rego"))
err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm)
require.Nil(t, err)

suits := []struct {
name string
policyPath string
request *networkservice.NetworkServiceRequest
response *networkservice.Connection
denied bool
}{
{
name: "simple positive test",
policyPath: policyPath,
request: requestWithToken("allowed"),
denied: false,
},
{
name: "simple negative test",
policyPath: policyPath,
request: requestWithToken("not_allowed"),
denied: true,
},
}

for i := range suits {
s := suits[i]
t.Run(s.name, func(t *testing.T) {
client := chain.NewNetworkServiceClient(
metadata.NewClient(),
authorize.NewClient(authorize.WithPolicies(s.policyPath)),
)
checkResult := func(err error) {
if !s.denied {
require.Nil(t, err, "request expected to be not denied: ")
return
}
require.NotNil(t, err, "request expected to be denied")
s, ok := status.FromError(errors.Cause(err))
require.True(t, ok, "error without error status code"+err.Error())
require.Equal(t, s.Code(), codes.PermissionDenied, "wrong error status code")
}

_, err := client.Request(context.Background(), s.request)
checkResult(err)

_, err = client.Close(context.Background(), s.request.GetConnection())
checkResult(err)
})
}
}

func TestAuthClientFailedRefresh(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

dir := filepath.Clean(path.Join(os.TempDir(), t.Name()))
defer func() {
_ = os.RemoveAll(dir)
}()

err := os.MkdirAll(dir, os.ModePerm)
require.Nil(t, err)

policyPath := filepath.Clean(path.Join(dir, "policy.rego"))
err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm)
require.Nil(t, err)

counter := new(count.Client)
client := chain.NewNetworkServiceClient(
metadata.NewClient(),
authorize.NewClient(authorize.WithPolicies(policyPath)),
counter,
)

conn, err := client.Request(context.Background(), requestWithToken("allowed"))
require.Nil(t, err)

refreshRequest := requestWithToken("not_allowed")
_, err = client.Request(context.Background(), refreshRequest)
require.NotNil(t, err)
require.Equal(t, 0, counter.Closes())

_, err = client.Close(context.Background(), conn)
require.Nil(t, err)
require.Equal(t, 1, counter.Closes())
}
42 changes: 42 additions & 0 deletions pkg/networkservice/common/authorize/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package authorize

import (
"context"

"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)

type key struct{}

// store sets a flag stored in per Connection.Id metadata.
func store(ctx context.Context, isClient bool) {
metadata.Map(ctx, isClient).Store(key{}, struct{}{})
}

// load returns a flag stored in per Connection.Id metadata.
// It is used to determine a refresh
func load(ctx context.Context, isClient bool) (ok bool) {
_, ok = metadata.Map(ctx, isClient).Load(key{})
return
}

// del deletes a flag stored in per Connection.Id metadata.
func del(ctx context.Context, isClient bool) {
metadata.Map(ctx, isClient).Delete(key{})
}
5 changes: 0 additions & 5 deletions pkg/tools/opa/opainput.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ func ParseX509Cert(authInfo credentials.AuthInfo) *x509.Certificate {
}
}

if tlsInfo, ok := authInfo.(*credentials.TLSInfo); ok {
if len(tlsInfo.State.PeerCertificates) > 0 {
peerCert = tlsInfo.State.PeerCertificates[0]
}
}
return peerCert
}

Expand Down
3 changes: 0 additions & 3 deletions pkg/tools/spiffejwt/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ func TokenGeneratorFunc(source x509svid.Source, maxTokenLifeTime time.Duration)
if ownSVID.Certificates[0].NotAfter.Before(expireTime) {
expireTime = ownSVID.Certificates[0].NotAfter
}
if err != nil {
return "", time.Time{}, errors.Wrap(err, "Error creating Token")
}
claims := jwt.RegisteredClaims{
Subject: ownSVID.ID.String(),
ExpiresAt: jwt.NewNumericDate(expireTime),
Expand Down

0 comments on commit e3a83f6

Please sign in to comment.