diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index 3c53f684..2ffe18dd 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -182,6 +182,7 @@ type Rule struct { Test bool `json:"test"` Block bool `json:"block"` AttackType string `json:"attack_type"` + Priority int `json:"priority"` } type RuleConditions struct{} diff --git a/internal/rule/callback/add-security-headers.go b/internal/rule/callback/add-security-headers.go index a9c4068e..b343eb2e 100644 --- a/internal/rule/callback/add-security-headers.go +++ b/internal/rule/callback/add-security-headers.go @@ -19,8 +19,7 @@ import ( // to be attached to compatible HTTP protection middlewares such as // `protection/http`. It adds HTTP headers provided by the rule's configuration. func NewAddSecurityHeadersCallback(rule RuleFace, cfg NativeCallbackConfig) (sqhook.PrologCallback, error) { - sqassert.NotNil(rule) - sqassert.NotNil(cfg) + sqassert.NotNil(rule, cfg) var headers http.Header data, ok := cfg.Data().([]interface{}) if !ok { diff --git a/internal/rule/instrumentation.go b/internal/rule/instrumentation.go index 8afe2367..82499d2f 100644 --- a/internal/rule/instrumentation.go +++ b/internal/rule/instrumentation.go @@ -16,7 +16,7 @@ type InstrumentationFace interface { } type HookFace interface { - Attach(prolog sqhook.PrologCallback) error + Attach(prologs ...sqhook.PrologCallback) error } type defaultInstrumentationImpl struct{} diff --git a/internal/rule/rule.go b/internal/rule/rule.go index 66cf88b1..efb9c043 100644 --- a/internal/rule/rule.go +++ b/internal/rule/rule.go @@ -21,6 +21,7 @@ package rule import ( "crypto/ecdsa" "io" + "sort" "github.com/sqreen/go-agent/internal/backend/api" "github.com/sqreen/go-agent/internal/metrics" @@ -35,7 +36,7 @@ type Engine struct { // at run time by atomically replacing a running rule. // TODO: write a test to check two HookFaces are correctly comparable // to find back a hook - hooks hookDescriptors + hooks hookDescriptorMap packID string enabled bool metricsEngine *metrics.Engine @@ -79,7 +80,7 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule) { // Create the new rule descriptors and replace the existing ones - var ruleDescriptors hookDescriptors + var ruleDescriptors hookDescriptorMap if len(rules) > 0 { e.logger.Debugf("security rules: loading rules from pack `%s`", packID) ruleDescriptors = newHookDescriptors(e, rules) @@ -87,7 +88,7 @@ func (e *Engine) SetRules(packID string, rules []api.Rule) { e.setRules(packID, ruleDescriptors) } -func (e *Engine) setRules(packID string, descriptors hookDescriptors) { +func (e *Engine) setRules(packID string, descriptors hookDescriptorMap) { // Firstly update already enabled hookpoints with their new callbacks in order // to avoid having a blank moment without any callback set. This case happens // when a rule is updated. @@ -96,7 +97,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { if e.enabled { // Attach the callback to the hook, possibly overwriting the previous one. e.logger.Debugf("security rules: attaching callback to `%s`", hook) - err := hook.Attach(descr.callback) + err := hook.Attach(descr.callbacks...) if err != nil { e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the prolog callback to `%s`", hook)) continue @@ -135,11 +136,11 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { +func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptorMap { logger := e.logger // Create and configure the list of callbacks according to the given rules - var hookDescriptors = make(hookDescriptors) + var hookDescriptors = make(hookDescriptorMap) for i := len(rules) - 1; i >= 0; i-- { r := rules[i] // Verify the signature @@ -168,6 +169,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { continue } + // Create the prolog callback + var prolog sqhook.PrologCallback switch hookpoint.Strategy { case "", "native": cfg, err := newNativeCallbackConfig(&r) @@ -176,26 +179,23 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { continue } - prolog, err := NewNativeCallback(hookpoint.Callback, callbackContext, cfg) + prolog, err = NewNativeCallback(hookpoint.Callback, callbackContext, cfg) if err != nil { logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name)) continue } - // Create the descriptor with everything required to be able to enable or - // disable it afterwards. - hookDescriptors.Set(hook, prolog) case "reflected": - prolog, err := NewReflectedCallback(hookpoint.Callback, callbackContext, &r) + prolog, err = NewReflectedCallback(hookpoint.Callback, callbackContext, &r) if err != nil { logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name)) continue } - // Create the descriptor with everything required to be able to enable or - // disable it afterwards. - hookDescriptors.Set(hook, prolog) } + // Create the descriptor with everything required to be able to enable or + // disable it afterwards. + hookDescriptors.Add(hook, prolog, r.Priority) } // Nothing in the end if len(hookDescriptors) == 0 { @@ -207,9 +207,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { // Enable the hooks of the ongoing configured rules. func (e *Engine) Enable() { for hook, descr := range e.hooks { - prolog := descr.callback e.logger.Debugf("security rules: attaching callback to hook `%s`", hook) - if err := hook.Attach(prolog); err != nil { + if err := hook.Attach(descr.callbacks...); err != nil { e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the callback to hook `%v`", hook)) } } @@ -235,23 +234,66 @@ func (e *Engine) Count() int { return len(e.hooks) } -type callbackWrapper struct { - callback sqhook.PrologCallback -} +type ( + hookDescriptorMap map[HookFace]hookDescriptor -func (c callbackWrapper) Close() error { - if closer, ok := c.callback.(io.Closer); ok { - return closer.Close() + hookDescriptor struct { + priorities []int + callbacks []sqhook.PrologCallback + closers []io.Closer + } +) + +func (m hookDescriptorMap) Add(hook HookFace, callback sqhook.PrologCallback, priority int) { + d, exists := m[hook] + closer, _ := callback.(io.Closer) + + if !exists { + // First insertion + var closers []io.Closer + if closer != nil { + closers = []io.Closer{closer} + } + m[hook] = hookDescriptor{ + priorities: []int{priority}, + callbacks: []sqhook.PrologCallback{callback}, + closers: closers, + } + return } - return nil -} -type hookDescriptors map[HookFace]callbackWrapper + // Not the first insertion. + // Look for the callback position i per ascending priority order + i := sort.Search(len(d.priorities), func(i int) bool { + return d.priorities[i] > priority + }) -func (m hookDescriptors) Set(hook HookFace, prolog sqhook.PrologCallback) { - m[hook] = callbackWrapper{prolog} + // Update the list of priorities + d.priorities = append(d.priorities, 0) + copy(d.priorities[i+1:], d.priorities[i:]) + d.priorities[i] = priority + + // Update the list of closers + if closer != nil { + d.closers = append(d.closers, closer) + } + + // Update the list of callbacks + d.callbacks = append(d.callbacks, nil) + copy(d.callbacks[i+1:], d.callbacks[i:]) + d.callbacks[i] = callback + + // Update the hook descriptor map entry with the new value + m[hook] = d } -func (m hookDescriptors) Get(hook HookFace) callbackWrapper { - return m[hook] +func (d hookDescriptor) Close() error { + var errs sqerrors.ErrorCollection + for _, c := range d.closers { + err := c.Close() + if err != nil { + errs.Add(err) + } + } + return errs.ToError() } diff --git a/internal/rule/rule_test.go b/internal/rule/rule_test.go index bc86122a..5d589164 100644 --- a/internal/rule/rule_test.go +++ b/internal/rule/rule_test.go @@ -34,6 +34,8 @@ func (i *instrumentationMockup) Health(expectedVersion string) error { type hookMockup struct{ mock.Mock } +var _ rule.HookFace = &hookMockup{} + func (i *instrumentationMockup) Find(symbol string) (rule.HookFace, error) { res := i.Called(symbol) err := res.Error(1) @@ -47,12 +49,22 @@ func (i *instrumentationMockup) ExpectFind(symbol string) *mock.Call { return i.On("Find", symbol) } -func (h *hookMockup) Attach(prolog sqhook.PrologCallback) error { - return h.Called(prolog).Error(0) +func (h *hookMockup) Attach(prologs ...sqhook.PrologCallback) error { + return h.Called(prologs).Error(0) } -func (h *hookMockup) ExpectAttach(prolog interface{}) *mock.Call { - return h.On("Attach", prolog) +func (h *hookMockup) ExpectAttach(prologs ...interface{}) *mock.Call { + var args interface{} + if l := len(prologs); l == 1 && prologs[0] == mock.Anything { + args = prologs[0] + } else { + prologArgs := make([]sqhook.PrologCallback, l) + for i, p := range prologs { + prologArgs[i] = p + } + args = prologArgs + } + return h.On("Attach", args) } func (h *hookMockup) PrologFuncType() reflect.Type { diff --git a/internal/rule/rule_unit_test.go b/internal/rule/rule_unit_test.go new file mode 100644 index 00000000..a74f1f2f --- /dev/null +++ b/internal/rule/rule_unit_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule + +import ( + "io" + "testing" + + "github.com/sqreen/go-agent/internal/sqlib/sqhook" + "github.com/stretchr/testify/require" +) + +type hookMockup struct{} + +func (h hookMockup) Attach(...sqhook.PrologCallback) error { + panic("should not be called") + // TODO: better API to avoid that? the map only needs a "comparable" key and + // doesn't matter about the hook interface. +} + +func TestHookDescriptors(t *testing.T) { + // Not actual callbacks but enough for this unit test. + // We need to use distinct types to correctly check the ordering. + + t.Run("multiple callbacks having the same priority", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + m.Add(key, 1, 1) + m.Add(key, 2, 1) + m.Add(key, 3, 1) + m.Add(key, 4, 1) + d := m[key] + require.Equal(t, []int{1, 1, 1, 1}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4}, d.callbacks) + require.Nil(t, d.closers) + }) + + t.Run("multiple callbacks having distinct priorities", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + + m.Add(key, 3, 2) + m.Add(key, 5, 3) + m.Add(key, 4, 2) + m.Add(key, 1, 1) + m.Add(key, 6, 3) + m.Add(key, 2, 1) + d := m[key] + require.Equal(t, []int{1, 1, 2, 2, 3, 3}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4, 5, 6}, d.callbacks) + require.Nil(t, d.closers) + }) + + t.Run("multiple callbacks with close methods", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + m.Add(key, myFakeCallback(7), 10) + m.Add(key, 3, 2) + m.Add(key, myFakeCallback(1), 1) + m.Add(key, 2, 1) + m.Add(key, myFakeCallback(5), 3) + m.Add(key, 4, 2) + m.Add(key, 6, 3) + + d := m[key] + require.Equal(t, []int{1, 1, 2, 2, 3, 3, 10}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{myFakeCallback(1), 2, 3, 4, myFakeCallback(5), 6, myFakeCallback(7)}, d.callbacks) + require.Equal(t, []io.Closer{myFakeCallback(7), myFakeCallback(1), myFakeCallback(5)}, d.closers) + }) +} + +type myFakeCallback int + +func (m myFakeCallback) Close() error { return nil } diff --git a/internal/sqlib/sqassert/assert.go b/internal/sqlib/sqassert/assert.go index 62abc355..305d9d09 100644 --- a/internal/sqlib/sqassert/assert.go +++ b/internal/sqlib/sqassert/assert.go @@ -20,9 +20,11 @@ func NoError(err error) { } } -func NotNil(v interface{}) { - if v == nil { - doPanic(sqerrors.New("sqassert: unexpected nil value")) +func NotNil(v ...interface{}) { + for _, v := range v { + if v == nil { + doPanic(sqerrors.New("sqassert: unexpected nil value")) + } } } diff --git a/internal/sqlib/sqassert/assert_disabled.go b/internal/sqlib/sqassert/assert_disabled.go index f0ca0426..96fa42f3 100644 --- a/internal/sqlib/sqassert/assert_disabled.go +++ b/internal/sqlib/sqassert/assert_disabled.go @@ -6,6 +6,6 @@ package sqassert -func True(bool) {} -func NoError(error) {} -func NotNil(interface{}) {} +func True(bool) {} +func NoError(error) {} +func NotNil(...interface{}) {} diff --git a/internal/sqlib/sqerrors/errors.go b/internal/sqlib/sqerrors/errors.go index 68d2e5c1..a1d06dad 100644 --- a/internal/sqlib/sqerrors/errors.go +++ b/internal/sqlib/sqerrors/errors.go @@ -6,6 +6,7 @@ package sqerrors import ( "fmt" + "strings" "time" "github.com/pkg/errors" @@ -171,3 +172,26 @@ func Timestamp(err error) (t time.Time, ok bool) { } return time.Time{}, false } + +type ErrorCollection []error + +func (c ErrorCollection) Error() string { + var s strings.Builder + s.WriteString("multiple errors occurred:") + for i, e := range c { + fmt.Fprintf(&s, " (error %d) %s;", i+1, e.Error()) + } + // Return the build string without the trailing `;` + return s.String()[:s.Len()-1] +} + +func (c *ErrorCollection) Add(e error) { + *c = append(*c, e) +} + +func (c ErrorCollection) ToError() error { + if len(c) == 0 { + return nil + } + return c +} diff --git a/internal/sqlib/sqerrors/errors_test.go b/internal/sqlib/sqerrors/errors_test.go index f29268d2..a626afd9 100644 --- a/internal/sqlib/sqerrors/errors_test.go +++ b/internal/sqlib/sqerrors/errors_test.go @@ -46,3 +46,12 @@ func TestWithInfo(t *testing.T) { require.Equal(t, info, got) }) } + +func TestErrorCollection(t *testing.T) { + var errs sqerrors.ErrorCollection + errs.Add(errors.New("error 1")) + errs.Add(errors.New("error 2")) + errs.Add(errors.New("error 3")) + errs.Add(errors.New("error 4")) + require.Equal(t, "multiple errors occurred: (error 1) error 1; (error 2) error 2; (error 3) error 3; (error 4) error 4", errs.Error()) +} diff --git a/internal/sqlib/sqhook/hook.go b/internal/sqlib/sqhook/hook.go index eb876648..9083c9c9 100644 --- a/internal/sqlib/sqhook/hook.go +++ b/internal/sqlib/sqhook/hook.go @@ -186,7 +186,7 @@ func normalizedHookID(symbol string) string { // add creates the hook object for function `fn`, adds it to the find map and // returns it. It returns an error if it is not possible. -func (t symbolIndexType) add(fn, prologVar interface{}) (*Hook, error) { +func (t symbolIndexType) add(fn, prologVar interface{}) (h *Hook, err error) { // Check fn is a non-nil function value if fn == nil { return nil, sqerrors.New("unexpected function argument value `nil`") @@ -194,18 +194,25 @@ func (t symbolIndexType) add(fn, prologVar interface{}) (*Hook, error) { fnValue := reflect.ValueOf(fn) fnType := fnValue.Type() if fnType.Kind() != reflect.Func { - return nil, sqerrors.Errorf("unexpected function argument type: expecting a function value but got `%v`", fn) + return nil, sqerrors.Errorf("unexpected function argument type: expecting a function value but got `%T`", fn) } // Get the symbol name symbol := runtime.FuncForPC(fnValue.Pointer()).Name() if symbol == "" { - return nil, sqerrors.Errorf("could not read the symbol name of function `%#v`", fn) + return nil, sqerrors.Errorf("could not read the symbol name of function `%T`", fn) } // Unvendor it so that it is not prefixed by `/vendor/` symbol = sqgo.Unvendor(symbol) + // Use the symbol name for better error messages + defer func() { + if err != nil { + err = sqerrors.Wrapf(err, "symbol `%s`", symbol) + } + }() + // The hook may have been already added by a previous lookup if hook, exists := t[symbol]; exists { return hook, nil @@ -241,29 +248,44 @@ func (h *Hook) String() string { // Attach atomically attaches a prolog function to the hook. The hook can be // disabled with a `nil` prolog value. -func (h *Hook) Attach(prolog PrologCallback) error { +func (h *Hook) Attach(prologs ...PrologCallback) error { addr := h.prologVarAddr - if prolog == nil { + if l := len(prologs); l == 0 || (l == 1 && prologs[0] == nil) { // Disable atomic.StorePointer(addr, nil) - // TODO: should we check if the attach cb has a Close() method? return nil } -loop: - for { - switch actual := prolog.(type) { - case ReflectedPrologCallback: - prolog = makePrologCallback(h, actual) - case PrologCallbackGetter: - prolog = actual.PrologCallback() - default: - break loop + prologCallbacks := make([]PrologCallback, len(prologs)) + for i, prolog := range prologs { + // Loop until the prolog type is not one of the above + loop: + for { + switch actual := prolog.(type) { + case ReflectedPrologCallback: + prolog = makePrologCallback(h, actual) + case PrologCallbackGetter: + prolog = actual.PrologCallback() + default: + // Final type + break loop + } + } + + if h.prologFuncType != reflect.TypeOf(prolog) { + return sqerrors.Errorf("unexpected prolog type for hook `%s`: got `%T`, wanted `%s`", h, prolog, h.prologFuncType) } + + prologCallbacks[i] = prolog } - if h.prologFuncType != reflect.TypeOf(prolog) { - return sqerrors.Errorf("unexpected prolog type for hook `%s`: got `%T`, wanted `%s`", h, prolog, h.prologFuncType) + // Create the prolog out of the prologCallbacks + var prolog PrologCallback + if l := len(prologCallbacks); l == 1 { + prolog = prologCallbacks[0] + } else { + // Create a dynamic function calling the prolog + prolog = makeMultiPrologCallback(h, prologCallbacks) } // Create a value having type "pointer to the prolog function" @@ -275,6 +297,39 @@ loop: return nil } +func makeMultiPrologCallback(h *Hook, prologs []PrologCallback) PrologCallback { + return makePrologCallback(h, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + safeCallErr := sqsafe.Call(func() error { + epilogs := make([]reflect.Value, 0, len(prologs)) + defer func() { + if len(epilogs) > 0 { + epilog = func(results []reflect.Value) { + for _, epilog := range epilogs { + epilog.Call(results) + } + } + } + }() + for _, prolog := range prologs { + prologValue := reflect.ValueOf(prolog) + results := prologValue.Call(params) + if r0 := results[0]; !r0.IsNil() { + epilogs = append(epilogs, r0) + } + if r1 := results[1]; !r1.IsNil() { + err = r1.Interface().(error) + return nil + } + } + return nil + }) + if safeCallErr != nil { + // TODO: log this error once + } + return epilog, err + }) +} + func makePrologCallback(h *Hook, prolog ReflectedPrologCallback) PrologCallback { prologFuncType := h.prologFuncType epilogFuncType := h.prologFuncType.Out(0) diff --git a/internal/sqlib/sqhook/hook_test.go b/internal/sqlib/sqhook/hook_test.go index cf7f21fc..18f1ec9b 100644 --- a/internal/sqlib/sqhook/hook_test.go +++ b/internal/sqlib/sqhook/hook_test.go @@ -36,14 +36,16 @@ func (example) myMethod() {} func (example) MyExportedMethod() {} func (*example) myMethodWithPointerReceiver() {} -func myFunction(_ int, _ string, _ bool) (float32, error) { return 0, nil } -func MyExportedFunction(_ int, _ string, _ bool) error { return nil } +func myFunction(_ int, _ string, _ bool) (float32, error) { return 0, nil } +func myFunction2(_ int, _ string, _ bool) (float32, error) { return 0, nil } +func MyExportedFunction(_ int, _ string, _ bool) error { return nil } var ( MyMethodSymbol = runtime.FuncForPC(reflect.ValueOf(example.myMethod).Pointer()).Name() MyMethodWithPointerRecvSymbol = runtime.FuncForPC(reflect.ValueOf((*example).myMethodWithPointerReceiver).Pointer()).Name() MyExportedMethodSymbol = runtime.FuncForPC(reflect.ValueOf(example.MyExportedMethod).Pointer()).Name() MyFunctionSymbol = runtime.FuncForPC(reflect.ValueOf(myFunction).Pointer()).Name() + MyFunction2Symbol = runtime.FuncForPC(reflect.ValueOf(myFunction2).Pointer()).Name() MyExportedFunctionSymbol = runtime.FuncForPC(reflect.ValueOf(MyExportedFunction).Pointer()).Name() ) @@ -53,6 +55,7 @@ var sortedSymbols = []string{ // Sorted by normalized name MyMethodSymbol, MyMethodWithPointerRecvSymbol, MyFunctionSymbol, + MyFunction2Symbol, } var expectedSymbols = map[string]internal.HookDescriptorFuncType{ @@ -84,6 +87,13 @@ var expectedSymbols = map[string]internal.HookDescriptorFuncType{ } }, + MyFunction2Symbol: func(d *internal.HookDescriptorType) { + *d = internal.HookDescriptorType{ + Func: myFunction2, + PrologVar: &MyFunctionProlog, + } + }, + MyExportedFunctionSymbol: func(d *internal.HookDescriptorType) { *d = internal.HookDescriptorType{ Func: MyExportedFunction, @@ -135,7 +145,7 @@ func TestGoAssumptions(t *testing.T) { require.Equal(t, (*sqhook.PrologCallback)(nil), cb) }) - t.Run("the first argument of a myMethod is the method receiver", func(t *testing.T) { + t.Run("the first argument of a method is the method receiver", func(t *testing.T) { require.Equal(t, reflect.TypeOf(example{}).Name(), reflect.TypeOf(example.myMethod).In(0).Name()) }) @@ -168,6 +178,14 @@ func TestFind(t *testing.T) { } } +type prologCallbackGetter struct { + prolog sqhook.PrologCallback +} + +func (p prologCallbackGetter) PrologCallback() sqhook.PrologCallback { + return p.prolog +} + func TestAttach(t *testing.T) { for _, tc := range []struct { Symbol string @@ -255,10 +273,16 @@ func TestAttach(t *testing.T) { return []reflect.Value{{}, {}} // not used by the test }) - checkProlog := func(t *testing.T) { + checkPrologAddr := func(t *testing.T, expected uintptr) { + // Read barrier using the prolog var + _ = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(reflect.ValueOf(descr.PrologVar).Pointer()))) + require.Equal(t, expected, reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) + } + + checkPrologAddrNotNil := func(t *testing.T) { // Read barrier using the prolog var _ = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(reflect.ValueOf(descr.PrologVar).Pointer()))) - require.Equal(t, expectedProlog.Pointer(), reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) + require.NotZero(t, reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) } t.Run(tc.Symbol, func(t *testing.T) { @@ -269,21 +293,84 @@ func TestAttach(t *testing.T) { hook, err := sqhook.Find(tc.Symbol) require.NoError(t, err) require.NotNil(t, hook) - // Attach the expected prolog function - err = hook.Attach(expectedProlog.Interface()) - require.NoError(t, err) - // Read back the prolog variable - checkProlog(t) + + t.Run("native prolog callback", func(t *testing.T) { + // Attach the expected prolog function + err = hook.Attach(expectedProlog.Interface()) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddr(t, expectedProlog.Pointer()) + }) + + t.Run("reflected prolog callback", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + err = hook.Attach(reflected) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + + t.Run("prolog callback getter", func(t *testing.T) { + t.Run("returning a native prolog", func(t *testing.T) { + // Attach the expected prolog function + err = hook.Attach(prologCallbackGetter{prolog: expectedProlog.Interface()}) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + + t.Run("returning a reflected prolog", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + err = hook.Attach(prologCallbackGetter{ + prolog: reflected, + }) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + }) + + t.Run("multiple prolog callbacks", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + native := expectedProlog.Interface() + err = hook.Attach(reflected, native, reflected, native, reflected, native, prologCallbackGetter{prolog: reflected}, prologCallbackGetter{prolog: expectedProlog.Interface()}) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) }) + t.Run("not expected prolog types", func(t *testing.T) { + hook, err := sqhook.Find(tc.Symbol) + require.NoError(t, err) + require.NotNil(t, hook) + require.NoError(t, hook.Attach(nil)) + for _, invalidProlog := range tc.InvalidPrologs { invalidProlog := invalidProlog t.Run(fmt.Sprintf("%T", invalidProlog), func(t *testing.T) { - hook, err := sqhook.Find(tc.Symbol) - require.NoError(t, err) - require.NotNil(t, hook) err = hook.Attach(invalidProlog) require.Error(t, err) + //checkPrologAddr(t, 0) + }) + + t.Run(fmt.Sprintf("%T along with the expected prolog callback", invalidProlog), func(t *testing.T) { + err = hook.Attach(expectedProlog.Interface(), invalidProlog) + require.Error(t, err) + //checkPrologAddr(t, 0) + + err = hook.Attach(invalidProlog, expectedProlog.Interface()) + require.Error(t, err) + //checkPrologAddr(t, 0) }) } }) diff --git a/internal/sqlib/sqhook/hook_unit_test.go b/internal/sqlib/sqhook/hook_unit_test.go new file mode 100644 index 00000000..37ab114a --- /dev/null +++ b/internal/sqlib/sqhook/hook_unit_test.go @@ -0,0 +1,553 @@ +// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package sqhook + +import ( + "errors" + "reflect" + "testing" + + fuzz "github.com/google/gofuzz" + "github.com/sqreen/go-agent/internal/sqlib/sqhook/internal" + "github.com/sqreen/go-agent/tools/testlib" + "github.com/stretchr/testify/require" +) + +func myFunction(int, string, bool) (float32, error) { return 0, nil } + +func TestInstrumentationError(t *testing.T) { + // Test that we would catch instrumentation mistakes - which should never + // happen + var myFunctionPrologVar *func(*int, *string, *bool) (func(*float32, *error), error) + + for _, tc := range []struct { + Name string + Fn interface{} + PrologVar interface{} + }{ + { + Name: "nil function", + Fn: nil, + PrologVar: &myFunctionPrologVar, + }, + + { + Name: "not a function", + Fn: 33, + PrologVar: &myFunctionPrologVar, + }, + + { + Name: "nil prolog var", + Fn: myFunction, + PrologVar: nil, + }, + + { + Name: "prolog var is not a pointer", + Fn: myFunction, + PrologVar: myFunctionPrologVar, + }, + } { + t.Run(tc.Name, func(t *testing.T) { + h, err := symbolIndexType{}.add(tc.Fn, tc.PrologVar) + require.Error(t, err) + require.Nil(t, h) + }) + } +} + +func TestReflectedCallback(t *testing.T) { + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func() (epilogType, error) + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + require.Len(t, params, 0) + return nil, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + epilog, err := prolog() + require.NoError(t, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return nil, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + fuzz.New().Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.NoError(t, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prologErr := errors.New("my error") + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return nil, prologErr + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + fuzz.New().Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.Error(t, err) + require.Equal(t, prologErr, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func(int, bool, string, float64, map[string]bool) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs, epilogArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prologErr := errors.New("my error") + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return func(results []reflect.Value) { + argValues := reflect.ValueOf(epilogArgs) + argTypes := argValues.Type() + require.Len(t, results, argTypes.NumField()) + for i := range results { + require.Equal(t, argTypes.Field(i).Type, results[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), results[i].Interface()) + } + }, prologErr + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + f := fuzz.New() + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.Error(t, err) + require.Equal(t, prologErr, err) + + require.NotNil(t, epilog) + f.Fuzz(&epilogArgs) + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C, epilogArgs.D, epilogArgs.E) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func(int, bool, string, float64, map[string]bool) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs, epilogArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return func(results []reflect.Value) { + argValues := reflect.ValueOf(epilogArgs) + argTypes := argValues.Type() + require.Len(t, results, argTypes.NumField()) + for i := range results { + require.Equal(t, argTypes.Field(i).Type, results[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), results[i].Interface()) + } + }, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + f := fuzz.New() + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.NoError(t, err) + + require.NotNil(t, epilog) + f.Fuzz(&epilogArgs) + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C, epilogArgs.D, epilogArgs.E) + }) +} + +func TestMultiCallback(t *testing.T) { + type epilogType = func(byte, rune, []string) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var ( + prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + epilogArgs struct { + A byte + B rune + C []string + } + hook = &Hook{prologFuncType: reflect.TypeOf(prologType(nil))} + f = fuzz.New() + order []int + ) + + makePrologFunc := func(t *testing.T, expectedOrder int, epilog epilogType, prologErr error) prologType { + return func(a int, b bool, c string, d float64, e map[string]bool) (epilogType, error) { + require.Equal(t, prologArgs.A, a) + require.Equal(t, prologArgs.B, b) + require.Equal(t, prologArgs.C, c) + require.Equal(t, prologArgs.D, d) + require.Equal(t, prologArgs.E, e) + order = append(order, expectedOrder) + return epilog, prologErr + } + } + + makeEpilogFunc := func(t *testing.T, expectedOrder int) epilogType { + return func(a byte, b rune, c []string) { + require.Equal(t, epilogArgs.A, a) + require.Equal(t, epilogArgs.B, b) + require.Equal(t, epilogArgs.C, c) + order = append(order, expectedOrder) + } + } + + for _, tc := range []struct { + Prologs []PrologCallback + ExpectedOrder []int + ExpectedError error + }{ + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, makeEpilogFunc(t, 14), nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, makeEpilogFunc(t, 17), nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, makeEpilogFunc(t, 19), nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 18, 20, 21}, + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), errors.New("my error")), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 11, 12, 13, 15, 16}, + ExpectedError: errors.New("my error"), + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), errors.New("my error 1")), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), errors.New("my error 2")), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 11}, + ExpectedError: errors.New("my error 1"), + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), errors.New("my error")), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 18, 20, 21}, + ExpectedError: errors.New("my error"), + }, + } { + tc := tc + t.Run("", func(t *testing.T) { + prolog, ok := makeMultiPrologCallback(hook, tc.Prologs).(prologType) + require.True(t, ok) + require.NotNil(t, prolog) + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + + if tc.ExpectedError != nil { + require.Error(t, err) + require.Equal(t, tc.ExpectedError, err) + } else { + require.NoError(t, err) + } + + require.NotNil(t, epilog) + + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C) + + require.Equal(t, tc.ExpectedOrder, order) + order = nil // TODO: avoid this test side-effect... + }) + } +} + +func TestHookTableLookup(t *testing.T) { + t.Run("nil", func(t *testing.T) { + myIndex := symbolIndexType{} + found, err := hookTableLookup(nil, testlib.RandUTF8String(), myIndex) + require.NoError(t, err) + require.Nil(t, found) + }) + + t.Run("empty", func(t *testing.T) { + myIndex := symbolIndexType{} + myTable := internal.HookTableType{} + found, err := hookTableLookup(myTable, testlib.RandUTF8String(), myIndex) + require.NoError(t, err) + require.Nil(t, found) + }) + + t.Run("having instrumentation errors", func(t *testing.T) { + myIndex := symbolIndexType{} + for _, tc := range []internal.HookTableType{ + { + func(d *internal.HookDescriptorType) { + // Nil values + *d = internal.HookDescriptorType{Func: nil, PrologVar: nil} + }, + }, + + { + func(d *internal.HookDescriptorType) { + // Nil Func value - Non-nil prolog var + var prologVar *func() + *d = internal.HookDescriptorType{Func: nil, PrologVar: &prologVar} + }, + }, + } { + tc := tc + t.Run("", func(t *testing.T) { + found, err := hookTableLookup(tc, testlib.RandUTF8String(), myIndex) + require.Error(t, err) + require.Nil(t, found) + }) + } + }) +} + +func TestPrologVarValidation(t *testing.T) { + for _, tc := range []struct { + fn, prolog interface{} + shouldSucceed bool + }{ + { + fn: (func())(nil), + prolog: (func() (func(), error))(nil), + shouldSucceed: true, + }, + + { // wrong arg count + fn: (func())(nil), + prolog: (func(*int) (func(), error))(nil), + }, + + { // wrong prolog arg type: should be *int + fn: (func(int))(nil), + prolog: (func(int) (func(), error))(nil), + }, + + { // wrong prolog arg type: should be *int + fn: (func(int))(nil), + prolog: (func(*int) (func(), error))(nil), + shouldSucceed: true, + }, + + { // wrong return count + fn: (func(int))(nil), + prolog: (func(*int) error)(nil), + }, + + { // wrong return type: wrong prolog type + fn: (func(int))(nil), + prolog: (func(*int) (func(string), error))(nil), + }, + + { // wrong return type: wrong error type + fn: (func(int))(nil), + prolog: (func(*int) (func(), bool))(nil), + }, + + { // wrong return count + fn: (func(int))(nil), + prolog: (func(*int) (func(), error, bool))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}, *error, *int), error))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}), error))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(), error))(nil), + }, + + { // wrong prolog type: wrong arg types + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(interface{}, interface{}, interface{}), error))(nil), + }, + + { // wrong prolog type: wrong arg types + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(chan struct{}, *error), error))(nil), + }, + + { + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), + shouldSucceed: true, + }, + + { // variadic func + fn: (func(...int))(nil), + prolog: (func(*[]int) (func(), error))(nil), + shouldSucceed: true, + }, + } { + tc := tc + t.Run("unexpected signatures", func(t *testing.T) { + fnType := reflect.TypeOf(tc.fn) + prologVarType := reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(tc.prolog))) + err := validatePrologVar(fnType, prologVarType) + if tc.shouldSucceed { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} diff --git a/internal/sqlib/sqhook/validation_test.go b/internal/sqlib/sqhook/validation_test.go deleted file mode 100644 index 33a319c2..00000000 --- a/internal/sqlib/sqhook/validation_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. -// Please refer to our terms for more information: -// https://www.sqreen.io/terms.html - -package sqhook - -import ( - "reflect" - "testing" - - "github.com/sqreen/go-agent/internal/sqlib/sqhook/internal" - "github.com/sqreen/go-agent/tools/testlib" - "github.com/stretchr/testify/require" -) - -func TestHookTableLookup(t *testing.T) { - t.Run("nil", func(t *testing.T) { - myIndex := symbolIndexType{} - found, err := hookTableLookup(nil, testlib.RandUTF8String(), myIndex) - require.NoError(t, err) - require.Nil(t, found) - }) - - t.Run("empty", func(t *testing.T) { - myIndex := symbolIndexType{} - myTable := internal.HookTableType{} - found, err := hookTableLookup(myTable, testlib.RandUTF8String(), myIndex) - require.NoError(t, err) - require.Nil(t, found) - }) - - t.Run("having instrumentation errors", func(t *testing.T) { - myIndex := symbolIndexType{} - for _, tc := range []internal.HookTableType{ - { - func(d *internal.HookDescriptorType) { - // Nil values - *d = internal.HookDescriptorType{Func: nil, PrologVar: nil} - }, - }, - - { - func(d *internal.HookDescriptorType) { - // Nil Func value - Non-nil prolog var - var prologVar *func() - *d = internal.HookDescriptorType{Func: nil, PrologVar: &prologVar} - }, - }, - } { - tc := tc - t.Run("", func(t *testing.T) { - found, err := hookTableLookup(tc, testlib.RandUTF8String(), myIndex) - require.Error(t, err) - require.Nil(t, found) - }) - } - }) -} - -func TestPrologVarValidation(t *testing.T) { - for _, tc := range []struct { - fn, prolog interface{} - shouldSucceed bool - }{ - { - fn: (func())(nil), - prolog: (func() (func(), error))(nil), - shouldSucceed: true, - }, - - { // wrong arg count - fn: (func())(nil), - prolog: (func(*int) (func(), error))(nil), - }, - - { // wrong prolog arg type: should be *int - fn: (func(int))(nil), - prolog: (func(int) (func(), error))(nil), - }, - - { // wrong prolog arg type: should be *int - fn: (func(int))(nil), - prolog: (func(*int) (func(), error))(nil), - shouldSucceed: true, - }, - - { // wrong return count - fn: (func(int))(nil), - prolog: (func(*int) error)(nil), - }, - - { // wrong return type: wrong prolog type - fn: (func(int))(nil), - prolog: (func(*int) (func(string), error))(nil), - }, - - { // wrong return type: wrong error type - fn: (func(int))(nil), - prolog: (func(*int) (func(), bool))(nil), - }, - - { // wrong return count - fn: (func(int))(nil), - prolog: (func(*int) (func(), error, bool))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}, *error, *int), error))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}), error))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(), error))(nil), - }, - - { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(interface{}, interface{}, interface{}), error))(nil), - }, - - { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(chan struct{}, *error), error))(nil), - }, - - { - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), - shouldSucceed: true, - }, - - { // variadic func - fn: (func(...int))(nil), - prolog: (func(*[]int) (func(), error))(nil), - shouldSucceed: true, - }, - } { - tc := tc - t.Run("unexpected signatures", func(t *testing.T) { - fnType := reflect.TypeOf(tc.fn) - prologVarType := reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(tc.prolog))) - err := validatePrologVar(fnType, prologVarType) - if tc.shouldSucceed { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) - } -}