Skip to content

Commit

Permalink
feat: new CallContext APIs (#17)
Browse files Browse the repository at this point in the history
* feat: new CallContext APIs

* fix TestGoMongo111

* consolidate dst node replication
  • Loading branch information
y1yang0 committed Jul 30, 2024
1 parent 2c080d6 commit 846a570
Show file tree
Hide file tree
Showing 16 changed files with 411 additions and 139 deletions.
56 changes: 17 additions & 39 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,21 @@ package api
// of the original function call. Modification of the Params and ReturnVals will
// affect the original function call thus should be used with caution.

type CallContext struct {
Params []interface{} // Address of parameters of original function
ReturnVals []interface{} // Address of return values of original function
Data interface{} // User defined data
SkipCall bool // Skip the original function call if set to true
}

func (ctx *CallContext) SetSkipCall(skip bool) {
ctx.SkipCall = skip
}

func (ctx *CallContext) SetData(data interface{}) {
ctx.Data = data
}

func (ctx *CallContext) GetData() interface{} {
return ctx.Data
}

func (ctx *CallContext) SetKeyData(key, val string) {
if ctx.Data == nil {
ctx.Data = make(map[string]string)
}
ctx.Data.(map[string]string)[key] = val
}

func (ctx *CallContext) GetKeyData(key string) string {
if ctx.Data == nil {
return ""
}
return ctx.Data.(map[string]string)[key]
}

func (ctx *CallContext) HasKeyData(key string) bool {
if ctx.Data == nil {
return false
}
_, ok := ctx.Data.(map[string]string)[key]
return ok
type CallContext interface {
// Skip the original function call
SetSkipCall(bool)
// Check if the original function call should be skipped
IsSkipCall() bool
// Set the data field, can be used to pass information between OnEnter&OnExit
SetData(interface{})
// Get the data field, can be used to pass information between OnEnter&OnExit
GetData() interface{}
// Get the original function parameter at index idx
GetParam(idx int) interface{}
// Change the original function parameter at index idx
SetParam(idx int, val interface{})
// Get the original function return value at index idx
GetReturnVal(idx int) interface{}
// Change the original function return value at index idx
SetReturnVal(idx int, val interface{})
}
5 changes: 3 additions & 2 deletions pkg/rules/mongo/client_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"context"
"errors"
"fmt"
"sync"

"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"sync"
)

var mongoInstrumenter = BuildMongoOtelInstrumenter()

func mongoOnEnter(call *mongo.CallContext, opts ...*options.ClientOptions) {
func mongoOnEnter(call mongo.CallContext, opts ...*options.ClientOptions) {
syncMap := sync.Map{}
for _, opt := range opts {
hosts := opt.Hosts
Expand Down
58 changes: 42 additions & 16 deletions pkg/rules/test/errors_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,66 @@ import (
"fmt"
)

func onEnterUnwrap(call *errors.CallContext, err error) {
func onEnterUnwrap(call errors.CallContext, err error) {
newErr := fmt.Errorf("wrapped: %w", err)
*(call.Params[0].(*error)) = newErr
call.SetParam(0, newErr)
}

func onExitUnwrap(call *errors.CallContext, err error) {
e := (*(call.Params[0].(*error))).(interface {
func onExitUnwrap(call errors.CallContext, err error) {
e := call.GetParam(0).(interface {
Unwrap() error
})
old := e.Unwrap()
fmt.Printf("old:%v\n", old)
}

func onEnterTestSkip(call *errors.CallContext) {
func onEnterTestSkip(call errors.CallContext) {
call.SetSkipCall(true)
}

func onExitTestSkipOnly(call *errors.CallContext, _ *int) {}
func onExitTestSkipOnly(call errors.CallContext, _ *int) {}

func onEnterTestSkipOnly(call *errors.CallContext) {}
func onEnterTestSkipOnly(call errors.CallContext) {}

func onEnterP11(call *errors.CallContext) {}
func onEnterP12(call *errors.CallContext) {}
func onEnterP11(call errors.CallContext) {}
func onEnterP12(call errors.CallContext) {}

func onExitP21(call *errors.CallContext) {}
func onExitP22(call *errors.CallContext) {}
func onExitP21(call errors.CallContext) {}
func onExitP22(call errors.CallContext) {}

func onEnterP31(call *errors.CallContext, arg1 int, arg2 bool, arg3 float64) {}
func onExitP31(call *errors.CallContext, arg1 int, arg2 bool, arg3 float64) {}
func onEnterP31(call errors.CallContext, arg1 int, arg2 bool, arg3 float64) {}
func onExitP31(call errors.CallContext, arg1 int, arg2 bool, arg3 float64) {}

func onEnterTestSkip2(call *errors.CallContext) {
func onEnterTestSkip2(call errors.CallContext) {
call.SetSkipCall(true)
}

func onExitTestSkip2(call *errors.CallContext, _ int) {
*(call.ReturnVals[0].(*int)) = 0x512
func onExitTestSkip2(call errors.CallContext, _ int) {
call.SetReturnVal(0, 0x512)
}

func onEnterTestGetSet(call errors.CallContext, arg1 int, arg2, arg3 bool, arg4 float64, arg5 string, arg6 interface{}, arg7, arg8 map[int]bool, arg9 chan int, arg10 []int) {
call.SetParam(0, 7632)
call.SetParam(1, arg2)
call.SetParam(2, arg3)
call.SetParam(3, arg4)
call.SetParam(4, arg5)
call.SetParam(5, arg6)
call.SetParam(6, arg7)
call.SetParam(7, arg8)
call.SetParam(8, arg9)
call.SetParam(9, arg10)
}

func onExitTestGetSet(call errors.CallContext, arg1 int, arg2 bool, arg3 bool, arg4 float64, arg5 string, arg6 interface{}, arg7 map[int]bool, arg8 map[int]bool, arg9 chan int, arg10 []int) {
call.SetReturnVal(0, arg1)
call.SetReturnVal(1, arg2)
call.SetReturnVal(2, arg3)
call.SetReturnVal(3, arg4)
call.SetReturnVal(4, arg5)
call.SetReturnVal(5, arg6)
call.SetReturnVal(6, arg7)
call.SetReturnVal(7, arg8)
call.SetReturnVal(8, arg9)
call.SetReturnVal(9, arg10)
}
27 changes: 14 additions & 13 deletions pkg/rules/test/fmt_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@ package test

import "fmt"

func OnExitPrintf1(call *fmt.CallContext, n int, err error) {
func OnExitPrintf1(call fmt.CallContext, n int, err error) {
println("Exiting hook1....")
*(call.ReturnVals[0].(*int)) = 1024
call.SetReturnVal(0, 1024)
v := call.GetData().(int)
println(v)
}

type any = interface{}

func OnEnterPrintf1(call *fmt.CallContext, format string, arg ...any) {
func OnEnterPrintf1(call fmt.CallContext, format string, arg ...any) {
println("Entering hook1....")
call.SetData(555)
*(call.Params[0].(*string)) = "olleH%s\n"
(*(call.Params[1].(*[]any)))[0] = "goodcatch"
call.SetParam(0, "olleH%s\n")
p1 := call.GetParam(1).([]any)
p1[0] = "goodcatch"
}

func OnEnterPrintf2(call *fmt.CallContext, format interface{}, arg ...interface{}) {
func OnEnterPrintf2(call fmt.CallContext, format interface{}, arg ...interface{}) {
println("hook2")
for i := 0; i < 10; i++ {
if i == 5 {
Expand All @@ -29,27 +30,27 @@ func OnEnterPrintf2(call *fmt.CallContext, format interface{}, arg ...interface{
}
}

func onEnterSprintf1(call *fmt.CallContext, format string, arg ...any) {
func onEnterSprintf1(call fmt.CallContext, format string, arg ...any) {
print("a1")
}

func onExitSprintf1(call *fmt.CallContext, s string) {
func onExitSprintf1(call fmt.CallContext, s string) {
print("b1")
}

func onEnterSprintf2(call *fmt.CallContext, format string, arg ...any) {
func onEnterSprintf2(call fmt.CallContext, format string, arg ...any) {
print("a2")
_ = call.SkipCall
_ = call.IsSkipCall()
}

func onExitSprintf2(call *fmt.CallContext, s string) {
func onExitSprintf2(call fmt.CallContext, s string) {
println("b2")
}

func onEnterSprintf3(call *fmt.CallContext, format string, arg ...any) {
func onEnterSprintf3(call fmt.CallContext, format string, arg ...any) {
println("a3")
}

func onExitSprintf3(call *fmt.CallContext, s string) {
func onExitSprintf3(call fmt.CallContext, s string) {
print("b3")
}
5 changes: 5 additions & 0 deletions pkg/rules/test/long/sub/p4.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ func p2() {}
func p3(arg1 int, arg2 bool, arg3 float64) (int, bool, float64) {
return arg1, arg2, arg3
}

func TestGetSet(arg1 int, arg2, arg3 bool, arg4 float64, arg5 string,
arg6 interface{}, arg7, arg8 map[int]bool, arg9 chan int, arg10 []int) (int, bool, bool, float64, string, interface{}, map[int]bool, map[int]bool, chan int, []int) {
return arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10
}
20 changes: 10 additions & 10 deletions pkg/rules/test/net_http_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,40 @@ import (
"net/http"
)

func onEnterClientDo(call *http.CallContext, recv *http.Client, req *http.Request) {
println("Client.Do()")
func onEnterClientDo(call http.CallContext, recv *http.Client, req *http.Request) {
println("Before Client.Do()")
}

func onExitClientDo(call *http.CallContext, resp *http.Response, err error) {
func onExitClientDo(call http.CallContext, resp *http.Response, err error) {
panic("deliberately")
}

// arg type has package prefix
func onEnterNewRequestWithContext(call *http.CallContext, ctx context.Context, method, url string, body io.Reader) {
func onEnterNewRequestWithContext(call http.CallContext, ctx context.Context, method, url string, body io.Reader) {
println("NewRequestWithContext()")
}

// many args have one type
func onEnterNewRequest(call *http.CallContext, method, url string, body io.Reader) {
func onEnterNewRequest(call http.CallContext, method, url string, body io.Reader) {
println("NewRequest()")
}

// many args have interface type
func onEnterNewRequest1(call *http.CallContext, a, b interface{}, c interface{}) {
func onEnterNewRequest1(call http.CallContext, a, b interface{}, c interface{}) {
println("NewRequest1()")
}

// only recv arg
func onEnterMaxBytesError(call *http.CallContext, recv *http.MaxBytesError) {
func onEnterMaxBytesError(call http.CallContext, recv *http.MaxBytesError) {
println("MaxBytesError()")
recv.Limit = 4008208820
}

func onExitMaxBytesError(call *http.CallContext, ret string) {
*(call.ReturnVals[0].(*string)) = "Prince of Qin Smashing the Battle line"
func onExitMaxBytesError(call http.CallContext, ret string) {
call.SetReturnVal(0, "Prince of Qin Smashing the Battle line")
}

// use field added by struct rule
func onExitNewRequest(call *http.CallContext, req *http.Request, _ interface{}) {
func onExitNewRequest(call http.CallContext, req *http.Request, _ interface{}) {
println(req.Should)
}
4 changes: 4 additions & 0 deletions pkg/rules/test/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,8 @@ func init() {
api.NewRule("errors", "TestSkip2", "", "onEnterTestSkip2", "onExitTestSkip2").
WithRuleName("testrule").
Register()

api.NewRule("errors", "TestGetSet", "", "onEnterTestGetSet", "onExitTestGetSet").
WithRuleName("testrule").
Register()
}
3 changes: 3 additions & 0 deletions test/errors-test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ func main() {

val := errors.TestSkip2()
fmt.Printf("val%v\n", val)

arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 := errors.TestGetSet(1, true, false, 3.14, "str", nil, map[int]bool{1: true}, map[int]bool{2: true}, make(chan int), []int{1, 2, 3})
fmt.Printf("val%v %v %v %v %v %v %v %v %v %v\n", arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10)
}
1 change: 1 addition & 0 deletions test/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func TestRunErrors(t *testing.T) {
ExpectContains(t, stdout, "ptr<nil>")
ExpectNotContains(t, stdout, "val1024")
ExpectContains(t, stdout, "val1298") // 0x512
ExpectContains(t, stdout, "7632")

text := ReadInstrumentLog(t, "debug_fn_otel_inst_file_p4.go")
re := regexp.MustCompile(".*OtelOnEnterTrampoline_TestSkip.*")
Expand Down
4 changes: 2 additions & 2 deletions tool/instrument/inst_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl
clone := make([]dst.Expr, len(retVals)+1)
clone[0] = shared.Ident(TrampolineCallContextName + varSuffix)
for i := 1; i < len(clone); i++ {
clone[i] = shared.AddressOf(dst.Clone(retVals[i-1]).(dst.Expr))
clone[i] = shared.AddressOf(retVals[i-1])
}
return clone
}())
Expand All @@ -168,7 +168,7 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl
shared.ReturnStmt(retVals),
),
Else: shared.Block(
shared.DeferStmt(dst.Clone(onExitCall).(*dst.CallExpr)),
shared.DeferStmt(onExitCall),
),
}
// Add this trampoline-jump-if as optimization candidates
Expand Down
2 changes: 2 additions & 0 deletions tool/instrument/instrument.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type RuleProcessor struct {
varDecls []dst.Decl
relocated map[string]string
trampolineJumps []*TJump // Optimization candidates
callCtxDecl *dst.GenDecl
callCtxMethods []*dst.FuncDecl
}

func newRuleProcessor(args []string, pkgName string) *RuleProcessor {
Expand Down
10 changes: 6 additions & 4 deletions tool/instrument/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ func replenishCallContextLiteral(tjump *TJump, expr dst.Expr) {
paramLiteral.Elts = names
}

func newCallContext(tjump *TJump) (dst.Expr, error) {
func (rp *RuleProcessor) newCallContextImpl(tjump *TJump) (dst.Expr, error) {
// One line please, otherwise debugging line number will be a nightmare
const newCallContext = `&CallContext{Params:[]interface{}{},ReturnVals:[]interface{}{},SkipCall:false,}`
astRoot, err := shared.ParseAstFromSnippet(newCallContext)
tmpl := fmt.Sprintf("&CallContextImpl%s{Params:[]interface{}{},ReturnVals:[]interface{}{}}",
rp.rule2Suffix[tjump.rule])
astRoot, err := shared.ParseAstFromSnippet(tmpl)
if err != nil {
return nil, fmt.Errorf("failed to parse new CallContext: %w", err)
}
Expand All @@ -147,7 +148,7 @@ func newCallContext(tjump *TJump) (dst.Expr, error) {

func (rp *RuleProcessor) removeOnEnterTrampolineCall(tjump *TJump) error {
// Construct CallContext on the fly and pass to onExit trampoline defer call
callContextExpr, err := newCallContext(tjump)
callContextExpr, err := rp.newCallContextImpl(tjump)
if err != nil {
return fmt.Errorf("failed to construct CallContext: %w", err)
}
Expand Down Expand Up @@ -218,6 +219,7 @@ func (rp *RuleProcessor) optimizeTJumps() (err error) {
// because there might be more than one trampoline-jump-if in the same
// function, they are nested in the else block. See findJumpPoint for
// more details.
// TODO: Remove corresponding CallContextImpl methods
rule := tjump.rule
removedOnExit := false
if rule.OnExit == "" {
Expand Down
Loading

0 comments on commit 846a570

Please sign in to comment.