Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v10] Backport #13506 #13720

Merged
merged 3 commits into from
Jun 22, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 83 additions & 55 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import (

type startTestCase struct {
name string
host types.Role
sessionKind types.SessionKind
host []types.Role
sessionKinds []types.SessionKind
participants []SessionAccessContext
expected bool
expected []bool
}

func successStartTestCase(t *testing.T) startTestCase {
Expand All @@ -39,22 +39,22 @@ func successStartTestCase(t *testing.T) startTestCase {

hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: types.OnSessionLeaveTerminate,
Modes: []string{"peer"},
}})

participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
}})

return startTestCase{
name: "success",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "success",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participants: []SessionAccessContext{
{
Username: "participant",
Expand All @@ -67,7 +67,7 @@ func successStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: true,
expected: []bool{true, true},
}
}

Expand All @@ -79,21 +79,21 @@ func failCountStartTestCase(t *testing.T) startTestCase {

hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 3,
Modes: []string{"peer"},
}})

participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
}})

return startTestCase{
name: "failCount",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failCount",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participants: []SessionAccessContext{
{
Username: "participant",
Expand All @@ -106,7 +106,7 @@ func failCountStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: false,
expected: []bool{false, false},
}
}

Expand All @@ -122,10 +122,10 @@ func succeedDiscardPolicySetStartTestCase(t *testing.T) startTestCase {
}})

return startTestCase{
name: "succeedDiscardPolicySet",
host: hostRole,
sessionKind: types.SSHSessionKind,
expected: true,
name: "succeedDiscardPolicySet",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind},
expected: []bool{true},
}
}

Expand All @@ -149,9 +149,9 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
}})

return startTestCase{
name: "failFilter",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failFilter",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participants: []SessionAccessContext{
{
Username: "participant",
Expand All @@ -164,7 +164,7 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: false,
expected: []bool{false},
}
}

Expand All @@ -178,21 +178,28 @@ func TestSessionAccessStart(t *testing.T) {

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
result, _, err := evaluator.FulfilledFor(testCase.participants)
require.NoError(t, err)
require.Equal(t, testCase.expected, result)
var policies []*types.SessionTrackerPolicySet
for _, role := range testCase.host {
policySet := role.GetSessionPolicySet()
policies = append(policies, &policySet)
}

for i, kind := range testCase.sessionKinds {
evaluator := NewSessionAccessEvaluator(policies, kind)
result, _, err := evaluator.FulfilledFor(testCase.participants)
require.NoError(t, err)
require.Equal(t, testCase.expected[i], result)
}
})
}
}

type joinTestCase struct {
name string
host types.Role
sessionKind types.SessionKind
participant SessionAccessContext
expected bool
name string
host types.Role
sessionKinds []types.SessionKind
participant SessionAccessContext
expected []bool
}

func successJoinTestCase(t *testing.T) joinTestCase {
Expand All @@ -203,19 +210,19 @@ func successJoinTestCase(t *testing.T) joinTestCase {

participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
}})

return joinTestCase{
name: "success",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "success",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: true,
expected: []bool{true, true},
}
}

Expand All @@ -227,19 +234,19 @@ func successGlobJoinTestCase(t *testing.T) joinTestCase {

participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{"*"},
Kinds: []string{string(types.SSHSessionKind)},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
}})

return joinTestCase{
name: "success",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "success",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: true,
expected: []bool{true, true},
}
}

Expand All @@ -250,14 +257,14 @@ func failRoleJoinTestCase(t *testing.T) joinTestCase {
require.NoError(t, err)

return joinTestCase{
name: "failRole",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failRole",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: false,
expected: []bool{false, false},
}
}

Expand All @@ -274,14 +281,32 @@ func failKindJoinTestCase(t *testing.T) joinTestCase {
}})

return joinTestCase{
name: "failKind",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failKind",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: []bool{false},
}
}

func versionDefaultJoinTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV5{})
require.NoError(t, err)
participantRole, err := types.NewRoleV3("participant", types.RoleSpecV5{})
require.NoError(t, err)

return joinTestCase{
name: "failVersion",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: false,
expected: []bool{true, false},
}
}

Expand All @@ -291,14 +316,17 @@ func TestSessionAccessJoin(t *testing.T) {
successGlobJoinTestCase(t),
failRoleJoinTestCase(t),
failKindJoinTestCase(t),
versionDefaultJoinTestCase(t),
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected, len(result) > 0)
for i, kind := range testCase.sessionKinds {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected[i], len(result) > 0)
}
})
}
}