Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

sqlib/sqhook: allow attaching multiple rules per hook point #136

Merged
merged 2 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/backend/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ type Rule struct {
Test bool `json:"test"`
Block bool `json:"block"`
AttackType string `json:"attack_type"`
Priority int `json:"priority"`
}

type RuleConditions struct{}
Expand Down
3 changes: 1 addition & 2 deletions internal/rule/callback/add-security-headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import (
// to be attached to compatible HTTP protection middlewares such as
// `protection/http`. It adds HTTP headers provided by the rule's configuration.
func NewAddSecurityHeadersCallback(rule RuleFace, cfg NativeCallbackConfig) (sqhook.PrologCallback, error) {
sqassert.NotNil(rule)
sqassert.NotNil(cfg)
sqassert.NotNil(rule, cfg)
var headers http.Header
data, ok := cfg.Data().([]interface{})
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion internal/rule/instrumentation.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type InstrumentationFace interface {
}

type HookFace interface {
Attach(prolog sqhook.PrologCallback) error
Attach(prologs ...sqhook.PrologCallback) error
}

type defaultInstrumentationImpl struct{}
Expand Down
100 changes: 71 additions & 29 deletions internal/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package rule
import (
"crypto/ecdsa"
"io"
"sort"

"github.com/sqreen/go-agent/internal/backend/api"
"github.com/sqreen/go-agent/internal/metrics"
Expand All @@ -35,7 +36,7 @@ type Engine struct {
// at run time by atomically replacing a running rule.
// TODO: write a test to check two HookFaces are correctly comparable
// to find back a hook
hooks hookDescriptors
hooks hookDescriptorMap
packID string
enabled bool
metricsEngine *metrics.Engine
Expand Down Expand Up @@ -79,15 +80,15 @@ func (e *Engine) PackID() string {
// them by atomically modifying the hooks, and removing what is left.
func (e *Engine) SetRules(packID string, rules []api.Rule) {
// Create the new rule descriptors and replace the existing ones
var ruleDescriptors hookDescriptors
var ruleDescriptors hookDescriptorMap
if len(rules) > 0 {
e.logger.Debugf("security rules: loading rules from pack `%s`", packID)
ruleDescriptors = newHookDescriptors(e, rules)
}
e.setRules(packID, ruleDescriptors)
}

func (e *Engine) setRules(packID string, descriptors hookDescriptors) {
func (e *Engine) setRules(packID string, descriptors hookDescriptorMap) {
// Firstly update already enabled hookpoints with their new callbacks in order
// to avoid having a blank moment without any callback set. This case happens
// when a rule is updated.
Expand All @@ -96,7 +97,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) {
if e.enabled {
// Attach the callback to the hook, possibly overwriting the previous one.
e.logger.Debugf("security rules: attaching callback to `%s`", hook)
err := hook.Attach(descr.callback)
err := hook.Attach(descr.callbacks...)
if err != nil {
e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the prolog callback to `%s`", hook))
continue
Expand Down Expand Up @@ -135,11 +136,11 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) {
// newHookDescriptors walks the list of received rules and creates the map of
// hook descriptors indexed by their hook pointer. A hook descriptor contains
// all it takes to enable and disable rules at run time.
func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors {
func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptorMap {
logger := e.logger

// Create and configure the list of callbacks according to the given rules
var hookDescriptors = make(hookDescriptors)
var hookDescriptors = make(hookDescriptorMap)
for i := len(rules) - 1; i >= 0; i-- {
r := rules[i]
// Verify the signature
Expand Down Expand Up @@ -168,6 +169,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors {
continue
}

// Create the prolog callback
var prolog sqhook.PrologCallback
switch hookpoint.Strategy {
case "", "native":
cfg, err := newNativeCallbackConfig(&r)
Expand All @@ -176,26 +179,23 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors {
continue
}

prolog, err := NewNativeCallback(hookpoint.Callback, callbackContext, cfg)
prolog, err = NewNativeCallback(hookpoint.Callback, callbackContext, cfg)
if err != nil {
logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name))
continue
}
// Create the descriptor with everything required to be able to enable or
// disable it afterwards.
hookDescriptors.Set(hook, prolog)

case "reflected":
prolog, err := NewReflectedCallback(hookpoint.Callback, callbackContext, &r)
prolog, err = NewReflectedCallback(hookpoint.Callback, callbackContext, &r)
if err != nil {
logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name))
continue
}
// Create the descriptor with everything required to be able to enable or
// disable it afterwards.
hookDescriptors.Set(hook, prolog)
}

// Create the descriptor with everything required to be able to enable or
// disable it afterwards.
hookDescriptors.Add(hook, prolog, r.Priority)
}
// Nothing in the end
if len(hookDescriptors) == 0 {
Expand All @@ -207,9 +207,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors {
// Enable the hooks of the ongoing configured rules.
func (e *Engine) Enable() {
for hook, descr := range e.hooks {
prolog := descr.callback
e.logger.Debugf("security rules: attaching callback to hook `%s`", hook)
if err := hook.Attach(prolog); err != nil {
if err := hook.Attach(descr.callbacks...); err != nil {
e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the callback to hook `%v`", hook))
}
}
Expand All @@ -235,23 +234,66 @@ func (e *Engine) Count() int {
return len(e.hooks)
}

type callbackWrapper struct {
callback sqhook.PrologCallback
}
type (
hookDescriptorMap map[HookFace]hookDescriptor

func (c callbackWrapper) Close() error {
if closer, ok := c.callback.(io.Closer); ok {
return closer.Close()
hookDescriptor struct {
priorities []int
callbacks []sqhook.PrologCallback
closers []io.Closer
}
)

func (m hookDescriptorMap) Add(hook HookFace, callback sqhook.PrologCallback, priority int) {
d, exists := m[hook]
closer, _ := callback.(io.Closer)

if !exists {
// First insertion
var closers []io.Closer
if closer != nil {
closers = []io.Closer{closer}
}
m[hook] = hookDescriptor{
priorities: []int{priority},
callbacks: []sqhook.PrologCallback{callback},
closers: closers,
}
return
}
return nil
}

type hookDescriptors map[HookFace]callbackWrapper
// Not the first insertion.
// Look for the callback position i per ascending priority order
i := sort.Search(len(d.priorities), func(i int) bool {
return d.priorities[i] > priority
})

func (m hookDescriptors) Set(hook HookFace, prolog sqhook.PrologCallback) {
m[hook] = callbackWrapper{prolog}
// Update the list of priorities
d.priorities = append(d.priorities, 0)
copy(d.priorities[i+1:], d.priorities[i:])
d.priorities[i] = priority

// Update the list of closers
if closer != nil {
d.closers = append(d.closers, closer)
}

// Update the list of callbacks
d.callbacks = append(d.callbacks, nil)
copy(d.callbacks[i+1:], d.callbacks[i:])
d.callbacks[i] = callback

// Update the hook descriptor map entry with the new value
m[hook] = d
}

func (m hookDescriptors) Get(hook HookFace) callbackWrapper {
return m[hook]
func (d hookDescriptor) Close() error {
var errs sqerrors.ErrorCollection
for _, c := range d.closers {
err := c.Close()
if err != nil {
errs.Add(err)
}
}
return errs.ToError()
}
20 changes: 16 additions & 4 deletions internal/rule/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func (i *instrumentationMockup) Health(expectedVersion string) error {

type hookMockup struct{ mock.Mock }

var _ rule.HookFace = &hookMockup{}

func (i *instrumentationMockup) Find(symbol string) (rule.HookFace, error) {
res := i.Called(symbol)
err := res.Error(1)
Expand All @@ -47,12 +49,22 @@ func (i *instrumentationMockup) ExpectFind(symbol string) *mock.Call {
return i.On("Find", symbol)
}

func (h *hookMockup) Attach(prolog sqhook.PrologCallback) error {
return h.Called(prolog).Error(0)
func (h *hookMockup) Attach(prologs ...sqhook.PrologCallback) error {
return h.Called(prologs).Error(0)
}

func (h *hookMockup) ExpectAttach(prolog interface{}) *mock.Call {
return h.On("Attach", prolog)
func (h *hookMockup) ExpectAttach(prologs ...interface{}) *mock.Call {
var args interface{}
if l := len(prologs); l == 1 && prologs[0] == mock.Anything {
args = prologs[0]
} else {
prologArgs := make([]sqhook.PrologCallback, l)
for i, p := range prologs {
prologArgs[i] = p
}
args = prologArgs
}
return h.On("Attach", args)
}

func (h *hookMockup) PrologFuncType() reflect.Type {
Expand Down
76 changes: 76 additions & 0 deletions internal/rule/rule_unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved.
// Please refer to our terms for more information:
// https://www.sqreen.io/terms.html

package rule

import (
"io"
"testing"

"github.com/sqreen/go-agent/internal/sqlib/sqhook"
"github.com/stretchr/testify/require"
)

type hookMockup struct{}

func (h hookMockup) Attach(...sqhook.PrologCallback) error {
panic("should not be called")
// TODO: better API to avoid that? the map only needs a "comparable" key and
// doesn't matter about the hook interface.
}

func TestHookDescriptors(t *testing.T) {
// Not actual callbacks but enough for this unit test.
// We need to use distinct types to correctly check the ordering.

t.Run("multiple callbacks having the same priority", func(t *testing.T) {
var m = hookDescriptorMap{}
key := hookMockup{}
m.Add(key, 1, 1)
m.Add(key, 2, 1)
m.Add(key, 3, 1)
m.Add(key, 4, 1)
d := m[key]
require.Equal(t, []int{1, 1, 1, 1}, d.priorities)
require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4}, d.callbacks)
require.Nil(t, d.closers)
})

t.Run("multiple callbacks having distinct priorities", func(t *testing.T) {
var m = hookDescriptorMap{}
key := hookMockup{}

m.Add(key, 3, 2)
m.Add(key, 5, 3)
m.Add(key, 4, 2)
m.Add(key, 1, 1)
m.Add(key, 6, 3)
m.Add(key, 2, 1)
d := m[key]
require.Equal(t, []int{1, 1, 2, 2, 3, 3}, d.priorities)
require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4, 5, 6}, d.callbacks)
require.Nil(t, d.closers)
})

t.Run("multiple callbacks with close methods", func(t *testing.T) {
var m = hookDescriptorMap{}
key := hookMockup{}
m.Add(key, myFakeCallback(7), 10)
m.Add(key, 3, 2)
m.Add(key, myFakeCallback(1), 1)
m.Add(key, 2, 1)
m.Add(key, myFakeCallback(5), 3)
m.Add(key, 4, 2)
m.Add(key, 6, 3)

d := m[key]
require.Equal(t, []int{1, 1, 2, 2, 3, 3, 10}, d.priorities)
require.Equal(t, []sqhook.PrologCallback{myFakeCallback(1), 2, 3, 4, myFakeCallback(5), 6, myFakeCallback(7)}, d.callbacks)
require.Equal(t, []io.Closer{myFakeCallback(7), myFakeCallback(1), myFakeCallback(5)}, d.closers)
})
}

type myFakeCallback int

func (m myFakeCallback) Close() error { return nil }
8 changes: 5 additions & 3 deletions internal/sqlib/sqassert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ func NoError(err error) {
}
}

func NotNil(v interface{}) {
if v == nil {
doPanic(sqerrors.New("sqassert: unexpected nil value"))
func NotNil(v ...interface{}) {
for _, v := range v {
if v == nil {
doPanic(sqerrors.New("sqassert: unexpected nil value"))
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions internal/sqlib/sqassert/assert_disabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

package sqassert

func True(bool) {}
func NoError(error) {}
func NotNil(interface{}) {}
func True(bool) {}
func NoError(error) {}
func NotNil(...interface{}) {}
24 changes: 24 additions & 0 deletions internal/sqlib/sqerrors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package sqerrors

import (
"fmt"
"strings"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -171,3 +172,26 @@ func Timestamp(err error) (t time.Time, ok bool) {
}
return time.Time{}, false
}

type ErrorCollection []error

func (c ErrorCollection) Error() string {
var s strings.Builder
s.WriteString("multiple errors occurred:")
for i, e := range c {
fmt.Fprintf(&s, " (error %d) %s;", i+1, e.Error())
}
// Return the build string without the trailing `;`
return s.String()[:s.Len()-1]
}

func (c *ErrorCollection) Add(e error) {
*c = append(*c, e)
}

func (c ErrorCollection) ToError() error {
if len(c) == 0 {
return nil
}
return c
}
Loading