Skip to content

Commit

Permalink
Merge pull request #19 from smallstep/herman/policy
Browse files Browse the repository at this point in the history
Add additional policy options and policy deduplication
  • Loading branch information
hslatman authored May 5, 2022
2 parents 7029556 + 4470bf6 commit 3eed2a0
Show file tree
Hide file tree
Showing 12 changed files with 405 additions and 76 deletions.
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
all: generate
all: generate test

test:
go test -race -coverpkg=./... -covermode=atomic ./...

generate:
protoc --proto_path=. --go_out=. --go-grpc_out=. --go_opt=paths=source_relative --go-grpc_opt=paths=source_relative provisioners.proto admin.proto config.proto eab.proto majordomo.proto policy.proto

.PHONY: all generate
.PHONY: all test generate
2 changes: 1 addition & 1 deletion admin.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 36 additions & 9 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ func NewContextWithAdmin(ctx context.Context, admin *Admin) context.Context {
return context.WithValue(ctx, adminContextKey, admin)
}

// AdminFromContext returns the Admin ctx carries.
// AdminFromContext returns an Admin if the ctx carries one and a
// bool indicating if an Admin is carried by the ctx.
func AdminFromContext(ctx context.Context) (a *Admin, ok bool) {
if a, ok = ctx.Value(adminContextKey).(*Admin); a == nil {
return nil, false
}
return
}

// MustAdminFromContext returns the Admin ctx carries.
//
// AdminFromContext panics in case ctx carries no Admin.
func AdminFromContext(ctx context.Context) *Admin {
// MustAdminFromContext panics in case ctx carries no Admin.
func MustAdminFromContext(ctx context.Context) *Admin {
return ctx.Value(adminContextKey).(*Admin)
}

Expand All @@ -28,10 +37,19 @@ func NewContextWithProvisioner(ctx context.Context, provisioner *Provisioner) co
return context.WithValue(ctx, provisionerContextKey, provisioner)
}

// ProvisionerFromContext returns the Provisioner ctx carries.
// ProvisionerFromContext returns a Provisioner if the ctx carries one and a
// bool indicating if a Provisioner is carried by the ctx.
func ProvisionerFromContext(ctx context.Context) (p *Provisioner, ok bool) {
if p, ok = ctx.Value(provisionerContextKey).(*Provisioner); p == nil {
return nil, false
}
return
}

// MustProvisionerFromContext returns the Provisioner ctx carries.
//
// ProvisionerFromContext panics in case ctx carries no Provisioner.
func ProvisionerFromContext(ctx context.Context) *Provisioner {
// MustProvisionerFromContext panics in case ctx carries no Provisioner.
func MustProvisionerFromContext(ctx context.Context) *Provisioner {
return ctx.Value(provisionerContextKey).(*Provisioner)
}

Expand All @@ -40,9 +58,18 @@ func NewContextWithExternalAccountKey(ctx context.Context, k *EABKey) context.Co
return context.WithValue(ctx, externalAccountKeyContextKey, k)
}

// ExternalAccountKeyFromContext returns the EABKey ctx carries.
// ExternalAccountKeyFromContext returns the EABKey if the ctx carries
// one and a bool indicating if an EABKey is carried by the ctx.
func ExternalAccountKeyFromContext(ctx context.Context) (k *EABKey, ok bool) {
if k, ok = ctx.Value(externalAccountKeyContextKey).(*EABKey); k == nil {
return nil, false
}
return
}

// MustExternalAccountKeyFromContext returns the EABKey ctx carries.
//
// ExternalAccountKeyFromContext panics in case ctx carries no EABKey.
func ExternalAccountKeyFromContext(ctx context.Context) *EABKey {
// MustExternalAccountKeyFromContext panics in case ctx carries no EABKey.
func MustExternalAccountKeyFromContext(ctx context.Context) *EABKey {
return ctx.Value(externalAccountKeyContextKey).(*EABKey)
}
67 changes: 56 additions & 11 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,89 @@ import (
func TestAdminFromContext(t *testing.T) {
t.Parallel()

exp := new(Admin)
// nil admin; expect false
var exp *Admin
got, ok := AdminFromContext(NewContextWithAdmin(context.Background(), exp))
assert.Same(t, exp, got)
assert.False(t, ok)

// non-nil admin; expect true
exp = new(Admin)
got, ok = AdminFromContext(NewContextWithAdmin(context.Background(), exp))
assert.Same(t, exp, got)
assert.True(t, ok)
}

func TestMustAdminFromContext(t *testing.T) {
t.Parallel()

got := AdminFromContext(NewContextWithAdmin(context.Background(), exp))
exp := new(Admin)
got := MustAdminFromContext(NewContextWithAdmin(context.Background(), exp))
assert.Same(t, exp, got)
}

func TestAdminFromContextPanics(t *testing.T) {
func TestMustAdminFromContextPanics(t *testing.T) {
t.Parallel()

assert.Panics(t, func() { AdminFromContext(context.Background()) })
assert.Panics(t, func() { MustAdminFromContext(context.Background()) })
}

func TestProvisionerFromContext(t *testing.T) {
t.Parallel()

exp := new(Provisioner)
// nil Provisioner; expect false
var exp *Provisioner
got, ok := ProvisionerFromContext(NewContextWithProvisioner(context.Background(), exp))
assert.Same(t, exp, got)
assert.False(t, ok)

// non-nil Provisioner; expect true
exp = new(Provisioner)
got, ok = ProvisionerFromContext(NewContextWithProvisioner(context.Background(), exp))
assert.Same(t, exp, got)
assert.True(t, ok)
}

func TestMustProvisionerFromContext(t *testing.T) {
t.Parallel()

got := ProvisionerFromContext(NewContextWithProvisioner(context.Background(), exp))
exp := new(Provisioner)
got := MustProvisionerFromContext(NewContextWithProvisioner(context.Background(), exp))
assert.Same(t, exp, got)
}

func TestProvisionerFromContextPanics(t *testing.T) {
func TestMustProvisionerFromContextPanics(t *testing.T) {
t.Parallel()

assert.Panics(t, func() { ProvisionerFromContext(context.Background()) })
assert.Panics(t, func() { MustProvisionerFromContext(context.Background()) })
}

func TestExternalAccountKeyFromContext(t *testing.T) {
t.Parallel()

exp := new(EABKey)
// nil EABKey; expect false
var exp *EABKey
got, ok := ExternalAccountKeyFromContext(NewContextWithExternalAccountKey(context.Background(), exp))
assert.Same(t, exp, got)
assert.False(t, ok)

got := ExternalAccountKeyFromContext(NewContextWithExternalAccountKey(context.Background(), exp))
// non-nil EABKey; expect true
exp = new(EABKey)
got, ok = ExternalAccountKeyFromContext(NewContextWithExternalAccountKey(context.Background(), exp))
assert.Same(t, exp, got)
assert.True(t, ok)
}

func TestMustExternalAccountKeyFromContext(t *testing.T) {
t.Parallel()

exp := new(EABKey)
got := MustExternalAccountKeyFromContext(NewContextWithExternalAccountKey(context.Background(), exp))
assert.Same(t, exp, got)
}

func TestExternalAccountKeyFromContextPanics(t *testing.T) {
t.Parallel()

assert.Panics(t, func() { ExternalAccountKeyFromContext(context.Background()) })
assert.Panics(t, func() { MustExternalAccountKeyFromContext(context.Background()) })
}
2 changes: 1 addition & 1 deletion eab.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion majordomo.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

71 changes: 71 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package linkedca

// Deduplicate removes duplicate values from the Policy
func (p *Policy) Deduplicate() {
if p == nil {
return
}
if x509 := p.GetX509(); x509 != nil {
if allow := x509.GetAllow(); allow != nil {
allow.Dns = removeDuplicates(allow.Dns)
allow.Ips = removeDuplicates(allow.Ips)
allow.Emails = removeDuplicates(allow.Emails)
allow.Uris = removeDuplicates(allow.Uris)
}
if deny := p.GetX509().GetDeny(); deny != nil {
deny.Dns = removeDuplicates(deny.Dns)
deny.Ips = removeDuplicates(deny.Ips)
deny.Emails = removeDuplicates(deny.Emails)
deny.Uris = removeDuplicates(deny.Uris)
}
}
if ssh := p.GetSsh(); ssh != nil {
if host := ssh.GetHost(); host != nil {
if allow := host.GetAllow(); allow != nil {
allow.Dns = removeDuplicates(allow.Dns)
allow.Ips = removeDuplicates(allow.Ips)
allow.Principals = removeDuplicates(allow.Principals)
}
if deny := host.GetDeny(); deny != nil {
deny.Dns = removeDuplicates(deny.Dns)
deny.Ips = removeDuplicates(deny.Ips)
deny.Principals = removeDuplicates(deny.Principals)
}
}
if user := ssh.GetUser(); user != nil {
if allow := user.GetAllow(); allow != nil {
allow.Emails = removeDuplicates(allow.Emails)
allow.Principals = removeDuplicates(allow.Principals)
}
if deny := user.GetDeny(); deny != nil {
deny.Emails = removeDuplicates(deny.Emails)
deny.Principals = removeDuplicates(deny.Principals)
}
}
}
}

// removeDuplicates returns a new slice of strings with
// duplicate values removed. It retains the order of elements
// in the source slice.
func removeDuplicates(tokens []string) (ret []string) {

// no need to remove dupes; return original
if len(tokens) <= 1 {
return tokens
}

keys := make(map[string]struct{}, len(tokens))

ret = make([]string, 0, len(tokens))
for _, item := range tokens {
if _, ok := keys[item]; ok {
continue
}

keys[item] = struct{}{}
ret = append(ret, item)
}

return
}
Loading

0 comments on commit 3eed2a0

Please sign in to comment.