From a2d2ea5e1b26b08b37271374cc2381e652d54871 Mon Sep 17 00:00:00 2001 From: Alexander Baryshnikov Date: Mon, 20 Jul 2020 20:37:14 +0800 Subject: [PATCH] add basic policies --- application/interfaces.go | 19 ++++ application/policy/checker.go | 26 +++++ application/policy/config.go | 66 +++++++++++++ application/policy/impl.go | 169 ++++++++++++++++++++++++++++++++ application/policy/impl_test.go | 102 +++++++++++++++++++ application/types.go | 13 +++ types/utils.go | 8 ++ 7 files changed, 403 insertions(+) create mode 100644 application/policy/checker.go create mode 100644 application/policy/config.go create mode 100644 application/policy/impl.go create mode 100644 application/policy/impl_test.go diff --git a/application/interfaces.go b/application/interfaces.go index 9356065..541c910 100644 --- a/application/interfaces.go +++ b/application/interfaces.go @@ -142,3 +142,22 @@ type Queues interface { // Find queues linked to lambda Find(targetLambda string) []Queue } + +// Manage policies for the all kind of resource. +// Lambda can have only one policy at one time, but one policy can be used by many lambdas. +type Policies interface { + // List all policies + List() []Policy + // Create new policy + Create(policy string, definition PolicyDefinition) (*Policy, error) + // Remove policy + Remove(policy string) error + // Update policy definition + Update(policy string, definition PolicyDefinition) error + // Apply policy for the resource + Apply(lambda string, policy string) error + // Clear applied policy for the lambda + Clear(lambda string) error + // Inspect request according policy (if applied). Returns null if all checks successful + Inspect(lambda string, request *types.Request) error +} diff --git a/application/policy/checker.go b/application/policy/checker.go new file mode 100644 index 0000000..b8011ef --- /dev/null +++ b/application/policy/checker.go @@ -0,0 +1,26 @@ +package policy + +import ( + "fmt" + "github.com/reddec/trusted-cgi/application" + "github.com/reddec/trusted-cgi/types" + "net" +) + +func checkPolicy(policy application.PolicyDefinition, req *types.Request) error { + host, _, _ := net.SplitHostPort(req.RemoteAddress) + if len(policy.AllowedIP) > 0 && !policy.AllowedIP.Has(host) { + return fmt.Errorf("IP restricted") + } + if len(policy.AllowedOrigin) > 0 && !policy.AllowedOrigin.Has(req.Headers["Origin"]) { + return fmt.Errorf("origin restricted") + } + + if !policy.Public { + _, ok := policy.Tokens[req.Headers["Authorization"]] + if !ok { + return fmt.Errorf("token restricted") + } + } + return nil +} diff --git a/application/policy/config.go b/application/policy/config.go new file mode 100644 index 0000000..dab7827 --- /dev/null +++ b/application/policy/config.go @@ -0,0 +1,66 @@ +package policy + +import ( + "github.com/reddec/trusted-cgi/application" + "github.com/reddec/trusted-cgi/internal" + "os" + "sync" +) + +type naiveFileStorePayload struct { + Policies []application.Policy `json:"policies"` +} + +func FileConfig(filename string) *naiveFileStore { + return &naiveFileStore{file: filename} +} + +type naiveFileStore struct { + file string + lock sync.RWMutex +} + +func (nfs *naiveFileStore) SetPolicies(policies []application.Policy) error { + nfs.lock.Lock() + defer nfs.lock.Unlock() + return internal.AtomicWriteJson(nfs.file, &naiveFileStorePayload{Policies: policies}) +} + +func (nfs *naiveFileStore) GetPolicies() ([]application.Policy, error) { + nfs.lock.RLock() + defer nfs.lock.RUnlock() + var pd naiveFileStorePayload + err := internal.ReadJson(nfs.file, &pd) + if err == nil { + return pd.Policies, nil + } + if os.IsNotExist(err) { + return nil, nil + } + return nil, err +} + +func Mock(policies ...application.Policy) *mockStore { + return &mockStore{policies: policies} +} + +type mockStore struct { + lock sync.RWMutex + policies []application.Policy +} + +func (msc *mockStore) SetPolicies(policies []application.Policy) error { + msc.lock.Lock() + defer msc.lock.Unlock() + msc.policies = make([]application.Policy, len(policies)) + copy(msc.policies, policies) + return nil +} + +func (msc *mockStore) GetPolicies() ([]application.Policy, error) { + msc.lock.RLock() + defer msc.lock.RUnlock() + out := make([]application.Policy, len(msc.policies)) + copy(out, msc.policies) + return out, nil +} diff --git a/application/policy/impl.go b/application/policy/impl.go new file mode 100644 index 0000000..fd33c3e --- /dev/null +++ b/application/policy/impl.go @@ -0,0 +1,169 @@ +package policy + +import ( + "fmt" + "github.com/reddec/trusted-cgi/application" + "github.com/reddec/trusted-cgi/types" + "sync" +) + +// Store contains policies configuration for reload +type Store interface { + // Save policies list + SetPolicies(policies []application.Policy) error + // Load policies list + GetPolicies() ([]application.Policy, error) +} + +func New(store Store) (*policiesImpl, error) { + impl := &policiesImpl{ + store: store, + policiesByID: map[string]*application.Policy{}, + policiesByLambda: map[string]string{}, + } + return impl, impl.load() +} + +type policiesImpl struct { + store Store + lock sync.RWMutex + policiesByID map[string]*application.Policy + policiesByLambda map[string]string +} + +func (policies *policiesImpl) load() error { + list, err := policies.store.GetPolicies() + if err != nil { + return err + } + for _, item := range list { + policies.policiesByID[item.ID] = &item + for lambda := range item.Lambdas { + policies.policiesByLambda[lambda] = item.ID + } + } + return nil +} + +func (policies *policiesImpl) List() []application.Policy { + policies.lock.RLock() + defer policies.lock.RUnlock() + return policies.unsafeList() +} + +func (policies *policiesImpl) Create(policy string, definition application.PolicyDefinition) (*application.Policy, error) { + policies.lock.Lock() + defer policies.lock.Unlock() + _, exist := policies.policiesByID[policy] + if exist { + return nil, fmt.Errorf("policy %s already exists", policy) + } + info := &application.Policy{ + ID: policy, + Definition: definition, + Lambdas: make(types.JsonStringSet), + } + policies.policiesByID[policy] = info + return info, policies.store.SetPolicies(policies.unsafeList()) +} + +func (policies *policiesImpl) Remove(policy string) error { + policies.lock.Lock() + defer policies.lock.Unlock() + info, exist := policies.policiesByID[policy] + if !exist { + return fmt.Errorf("policy %s does not exists", policy) + } + for lambda := range info.Lambdas { + delete(policies.policiesByLambda, lambda) + } + delete(policies.policiesByID, policy) + return policies.store.SetPolicies(policies.unsafeList()) +} + +func (policies *policiesImpl) Update(policy string, definition application.PolicyDefinition) error { + policies.lock.Lock() + defer policies.lock.Unlock() + info, exist := policies.policiesByID[policy] + if !exist { + return fmt.Errorf("policy %s does not exists", policy) + } + info.Definition = definition + return policies.store.SetPolicies(policies.unsafeList()) +} + +func (policies *policiesImpl) Apply(lambda string, policy string) error { + policies.lock.Lock() + defer policies.lock.Unlock() + info, exists := policies.policiesByID[policy] + if !exists { + return fmt.Errorf("policy %s does not exist", policy) + } + if info.Lambdas.Has(lambda) { + // already applied + return nil + } + policies.unsafeUnlink(lambda) + info.Lambdas.Set(lambda) + policies.policiesByLambda[lambda] = policy + return policies.store.SetPolicies(policies.unsafeList()) +} + +func (policies *policiesImpl) Inspect(lambda string, request *types.Request) error { + policy, applicable, err := policies.findPolicy(lambda) + if err != nil { + return err + } + if !applicable { + return nil + } + return checkPolicy(policy, request) +} + +func (policies *policiesImpl) Clear(lambda string) error { + policies.lock.Lock() + defer policies.lock.Unlock() + if !policies.unsafeUnlink(lambda) { + return nil + } + return policies.store.SetPolicies(policies.unsafeList()) +} + +func (policies *policiesImpl) unsafeUnlink(lambda string) bool { + policyId, hasPolicy := policies.policiesByLambda[lambda] + if !hasPolicy { + return false + } + // remove direct ref + delete(policies.policiesByLambda, lambda) + + // remove back ref + if policy, exist := policies.policiesByID[policyId]; exist { + policy.Lambdas.Del(lambda) + } + return true +} + +func (policies *policiesImpl) unsafeList() []application.Policy { + var ans = make([]application.Policy, 0, len(policies.policiesByID)) + for _, policy := range policies.policiesByID { + ans = append(ans, *policy) + } + return ans +} + +func (policies *policiesImpl) findPolicy(lambda string) (policy application.PolicyDefinition, applicable bool, err error) { + policies.lock.RLock() + defer policies.lock.RUnlock() + policyId, exists := policies.policiesByLambda[lambda] + if !exists { + applicable = false + return // no applied policy + } + info, exists := policies.policiesByID[policyId] + if !exists { + err = fmt.Errorf("corrupted policy data: lambda %s linked to unknown policy %s", lambda, policyId) + return + } + return info.Definition, true, nil +} diff --git a/application/policy/impl_test.go b/application/policy/impl_test.go new file mode 100644 index 0000000..429eaef --- /dev/null +++ b/application/policy/impl_test.go @@ -0,0 +1,102 @@ +package policy + +import ( + "bytes" + "github.com/reddec/trusted-cgi/application" + "github.com/reddec/trusted-cgi/types" + "github.com/stretchr/testify/assert" + "io/ioutil" + "testing" +) + +func TestNew(t *testing.T) { + policy, err := New(Mock(application.Policy{ + ID: "foo", + Definition: application.PolicyDefinition{ + Public: false, + Tokens: map[string]string{ + "DEADBEAF": "Consumer 1", + "BEAFDEAD": "Consumer 2", + }, + }, + Lambdas: map[string]bool{ + "lambda-1": true, + "lambda-2": true, + }, + })) + if err != nil { + t.Error(err) + return + } + t.Run("no applied policy", func(t *testing.T) { + err := policy.Inspect("lambda-3", mockRequest("hello")) + if err != nil { + t.Error(err) + } + }) + t.Run("valid token", func(t *testing.T) { + req := mockRequest("hello") + req.Headers["Authorization"] = "DEADBEAF" + err := policy.Inspect("lambda-1", req) + if err != nil { + t.Error(err) + } + }) + t.Run("invalid token", func(t *testing.T) { + req := mockRequest("hello") + req.Headers["Authorization"] = "1111" + err := policy.Inspect("lambda-1", req) + if err == nil { + t.Error("should fail") + } + }) + t.Run("list policies", func(t *testing.T) { + list := policy.List() + assert.Len(t, list, 1) + assert.Equal(t, "foo", list[0].ID) + assert.Equal(t, list[0].Lambdas, types.StringSet("lambda-1", "lambda-2")) + }) + t.Run("clear", func(t *testing.T) { + err := policy.Clear("lambda-2") + assert.NoError(t, err) + list := policy.List() + assert.Len(t, list, 1) + assert.Equal(t, "foo", list[0].ID) + assert.Equal(t, list[0].Lambdas, types.StringSet("lambda-1")) + }) + t.Run("apply", func(t *testing.T) { + err := policy.Apply("lambda-4", "foo") + assert.NoError(t, err) + list := policy.List() + assert.Len(t, list, 1) + assert.Equal(t, "foo", list[0].ID) + assert.Contains(t, list[0].Lambdas, "lambda-4") + }) + t.Run("update", func(t *testing.T) { + err := policy.Update("foo", application.PolicyDefinition{ + AllowedOrigin: types.StringSet("google"), + Public: true, + }) + assert.NoError(t, err) + req := mockRequest("hello") + req.Headers["Origin"] = "google" + err = policy.Inspect("lambda-1", req) + assert.NoError(t, err) + }) +} + +func mockRequest(payload string) *types.Request { + return &types.Request{ + Method: "POST", + URL: "http://example.com:8889/sample/" + payload, + Path: "/sample/" + payload, + RemoteAddress: "127.0.0.2:9992", + Form: map[string]string{ + "USER": "user1", + }, + Headers: map[string]string{ + "Content-Type": "text/plain", + }, + Body: ioutil.NopCloser(bytes.NewBufferString(payload)), + } +} diff --git a/application/types.go b/application/types.go index ce7e59c..eadba8b 100644 --- a/application/types.go +++ b/application/types.go @@ -55,3 +55,16 @@ type Queue struct { Retry int `json:"retry"` // number of additional attempts Interval types.JsonDuration `json:"interval"` // delay between attempts } + +type PolicyDefinition struct { + AllowedIP types.JsonStringSet `json:"allowed_ip,omitempty"` // limit incoming connections from list of IP + AllowedOrigin types.JsonStringSet `json:"allowed_origin,omitempty"` // limit incoming connections by origin header + Public bool `json:"public"` // if public, tokens are ignores + Tokens map[string]string `json:"tokens,omitempty"` // limit request by value in Authorization header (token => title) +} + +type Policy struct { + ID string `json:"id"` + Definition PolicyDefinition `json:"definition"` + Lambdas types.JsonStringSet `json:"lambdas"` +} diff --git a/types/utils.go b/types/utils.go index b970df8..3dac88e 100644 --- a/types/utils.go +++ b/types/utils.go @@ -4,6 +4,14 @@ import "encoding/json" type JsonStringSet map[string]bool +func StringSet(values ...string) JsonStringSet { + var ans = make(JsonStringSet) + for _, v := range values { + ans[v] = true + } + return ans +} + func (s *JsonStringSet) MarshalJSON() ([]byte, error) { var keys = make([]string, 0, len(*s)) for k := range *s {