From 846a5701aa7348bb0be333457a8e5d0824ead6c0 Mon Sep 17 00:00:00 2001 From: Yi Yang Date: Tue, 30 Jul 2024 10:43:01 +0800 Subject: [PATCH] feat: new CallContext APIs (#17) * feat: new CallContext APIs * fix TestGoMongo111 * consolidate dst node replication --- api/api.go | 56 +++----- pkg/rules/mongo/client_setup.go | 5 +- pkg/rules/test/errors_hook.go | 58 +++++--- pkg/rules/test/fmt_hook.go | 27 ++-- pkg/rules/test/long/sub/p4.go | 5 + pkg/rules/test/net_http_hook.go | 20 +-- pkg/rules/test/rule.go | 4 + test/errors-test/main.go | 3 + test/errors_test.go | 1 + tool/instrument/inst_func.go | 4 +- tool/instrument/instrument.go | 2 + tool/instrument/optimize.go | 10 +- tool/instrument/template.go | 58 +++++--- tool/instrument/trampoline.go | 237 +++++++++++++++++++++++++++----- tool/preprocess/dependency.go | 3 + tool/shared/ast.go | 57 +++++++- 16 files changed, 411 insertions(+), 139 deletions(-) diff --git a/api/api.go b/api/api.go index c047d704..b86d863d 100644 --- a/api/api.go +++ b/api/api.go @@ -9,43 +9,21 @@ package api // of the original function call. Modification of the Params and ReturnVals will // affect the original function call thus should be used with caution. -type CallContext struct { - Params []interface{} // Address of parameters of original function - ReturnVals []interface{} // Address of return values of original function - Data interface{} // User defined data - SkipCall bool // Skip the original function call if set to true -} - -func (ctx *CallContext) SetSkipCall(skip bool) { - ctx.SkipCall = skip -} - -func (ctx *CallContext) SetData(data interface{}) { - ctx.Data = data -} - -func (ctx *CallContext) GetData() interface{} { - return ctx.Data -} - -func (ctx *CallContext) SetKeyData(key, val string) { - if ctx.Data == nil { - ctx.Data = make(map[string]string) - } - ctx.Data.(map[string]string)[key] = val -} - -func (ctx *CallContext) GetKeyData(key string) string { - if ctx.Data == nil { - return "" - } - return ctx.Data.(map[string]string)[key] -} - -func (ctx *CallContext) HasKeyData(key string) bool { - if ctx.Data == nil { - return false - } - _, ok := ctx.Data.(map[string]string)[key] - return ok +type CallContext interface { + // Skip the original function call + SetSkipCall(bool) + // Check if the original function call should be skipped + IsSkipCall() bool + // Set the data field, can be used to pass information between OnEnter&OnExit + SetData(interface{}) + // Get the data field, can be used to pass information between OnEnter&OnExit + GetData() interface{} + // Get the original function parameter at index idx + GetParam(idx int) interface{} + // Change the original function parameter at index idx + SetParam(idx int, val interface{}) + // Get the original function return value at index idx + GetReturnVal(idx int) interface{} + // Change the original function return value at index idx + SetReturnVal(idx int, val interface{}) } diff --git a/pkg/rules/mongo/client_setup.go b/pkg/rules/mongo/client_setup.go index e27880da..3e5374c8 100644 --- a/pkg/rules/mongo/client_setup.go +++ b/pkg/rules/mongo/client_setup.go @@ -6,15 +6,16 @@ import ( "context" "errors" "fmt" + "sync" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "sync" ) var mongoInstrumenter = BuildMongoOtelInstrumenter() -func mongoOnEnter(call *mongo.CallContext, opts ...*options.ClientOptions) { +func mongoOnEnter(call mongo.CallContext, opts ...*options.ClientOptions) { syncMap := sync.Map{} for _, opt := range opts { hosts := opt.Hosts diff --git a/pkg/rules/test/errors_hook.go b/pkg/rules/test/errors_hook.go index a386f3dc..4d6a2916 100644 --- a/pkg/rules/test/errors_hook.go +++ b/pkg/rules/test/errors_hook.go @@ -7,40 +7,66 @@ import ( "fmt" ) -func onEnterUnwrap(call *errors.CallContext, err error) { +func onEnterUnwrap(call errors.CallContext, err error) { newErr := fmt.Errorf("wrapped: %w", err) - *(call.Params[0].(*error)) = newErr + call.SetParam(0, newErr) } -func onExitUnwrap(call *errors.CallContext, err error) { - e := (*(call.Params[0].(*error))).(interface { +func onExitUnwrap(call errors.CallContext, err error) { + e := call.GetParam(0).(interface { Unwrap() error }) old := e.Unwrap() fmt.Printf("old:%v\n", old) } -func onEnterTestSkip(call *errors.CallContext) { +func onEnterTestSkip(call errors.CallContext) { call.SetSkipCall(true) } -func onExitTestSkipOnly(call *errors.CallContext, _ *int) {} +func onExitTestSkipOnly(call errors.CallContext, _ *int) {} -func onEnterTestSkipOnly(call *errors.CallContext) {} +func onEnterTestSkipOnly(call errors.CallContext) {} -func onEnterP11(call *errors.CallContext) {} -func onEnterP12(call *errors.CallContext) {} +func onEnterP11(call errors.CallContext) {} +func onEnterP12(call errors.CallContext) {} -func onExitP21(call *errors.CallContext) {} -func onExitP22(call *errors.CallContext) {} +func onExitP21(call errors.CallContext) {} +func onExitP22(call errors.CallContext) {} -func onEnterP31(call *errors.CallContext, arg1 int, arg2 bool, arg3 float64) {} -func onExitP31(call *errors.CallContext, arg1 int, arg2 bool, arg3 float64) {} +func onEnterP31(call errors.CallContext, arg1 int, arg2 bool, arg3 float64) {} +func onExitP31(call errors.CallContext, arg1 int, arg2 bool, arg3 float64) {} -func onEnterTestSkip2(call *errors.CallContext) { +func onEnterTestSkip2(call errors.CallContext) { call.SetSkipCall(true) } -func onExitTestSkip2(call *errors.CallContext, _ int) { - *(call.ReturnVals[0].(*int)) = 0x512 +func onExitTestSkip2(call errors.CallContext, _ int) { + call.SetReturnVal(0, 0x512) +} + +func onEnterTestGetSet(call errors.CallContext, arg1 int, arg2, arg3 bool, arg4 float64, arg5 string, arg6 interface{}, arg7, arg8 map[int]bool, arg9 chan int, arg10 []int) { + call.SetParam(0, 7632) + call.SetParam(1, arg2) + call.SetParam(2, arg3) + call.SetParam(3, arg4) + call.SetParam(4, arg5) + call.SetParam(5, arg6) + call.SetParam(6, arg7) + call.SetParam(7, arg8) + call.SetParam(8, arg9) + call.SetParam(9, arg10) +} + +func onExitTestGetSet(call errors.CallContext, arg1 int, arg2 bool, arg3 bool, arg4 float64, arg5 string, arg6 interface{}, arg7 map[int]bool, arg8 map[int]bool, arg9 chan int, arg10 []int) { + call.SetReturnVal(0, arg1) + call.SetReturnVal(1, arg2) + call.SetReturnVal(2, arg3) + call.SetReturnVal(3, arg4) + call.SetReturnVal(4, arg5) + call.SetReturnVal(5, arg6) + call.SetReturnVal(6, arg7) + call.SetReturnVal(7, arg8) + call.SetReturnVal(8, arg9) + call.SetReturnVal(9, arg10) } diff --git a/pkg/rules/test/fmt_hook.go b/pkg/rules/test/fmt_hook.go index 21b4aaa1..3780077a 100644 --- a/pkg/rules/test/fmt_hook.go +++ b/pkg/rules/test/fmt_hook.go @@ -4,23 +4,24 @@ package test import "fmt" -func OnExitPrintf1(call *fmt.CallContext, n int, err error) { +func OnExitPrintf1(call fmt.CallContext, n int, err error) { println("Exiting hook1....") - *(call.ReturnVals[0].(*int)) = 1024 + call.SetReturnVal(0, 1024) v := call.GetData().(int) println(v) } type any = interface{} -func OnEnterPrintf1(call *fmt.CallContext, format string, arg ...any) { +func OnEnterPrintf1(call fmt.CallContext, format string, arg ...any) { println("Entering hook1....") call.SetData(555) - *(call.Params[0].(*string)) = "olleH%s\n" - (*(call.Params[1].(*[]any)))[0] = "goodcatch" + call.SetParam(0, "olleH%s\n") + p1 := call.GetParam(1).([]any) + p1[0] = "goodcatch" } -func OnEnterPrintf2(call *fmt.CallContext, format interface{}, arg ...interface{}) { +func OnEnterPrintf2(call fmt.CallContext, format interface{}, arg ...interface{}) { println("hook2") for i := 0; i < 10; i++ { if i == 5 { @@ -29,27 +30,27 @@ func OnEnterPrintf2(call *fmt.CallContext, format interface{}, arg ...interface{ } } -func onEnterSprintf1(call *fmt.CallContext, format string, arg ...any) { +func onEnterSprintf1(call fmt.CallContext, format string, arg ...any) { print("a1") } -func onExitSprintf1(call *fmt.CallContext, s string) { +func onExitSprintf1(call fmt.CallContext, s string) { print("b1") } -func onEnterSprintf2(call *fmt.CallContext, format string, arg ...any) { +func onEnterSprintf2(call fmt.CallContext, format string, arg ...any) { print("a2") - _ = call.SkipCall + _ = call.IsSkipCall() } -func onExitSprintf2(call *fmt.CallContext, s string) { +func onExitSprintf2(call fmt.CallContext, s string) { println("b2") } -func onEnterSprintf3(call *fmt.CallContext, format string, arg ...any) { +func onEnterSprintf3(call fmt.CallContext, format string, arg ...any) { println("a3") } -func onExitSprintf3(call *fmt.CallContext, s string) { +func onExitSprintf3(call fmt.CallContext, s string) { print("b3") } diff --git a/pkg/rules/test/long/sub/p4.go b/pkg/rules/test/long/sub/p4.go index a22fc353..c4d300b5 100644 --- a/pkg/rules/test/long/sub/p4.go +++ b/pkg/rules/test/long/sub/p4.go @@ -17,3 +17,8 @@ func p2() {} func p3(arg1 int, arg2 bool, arg3 float64) (int, bool, float64) { return arg1, arg2, arg3 } + +func TestGetSet(arg1 int, arg2, arg3 bool, arg4 float64, arg5 string, + arg6 interface{}, arg7, arg8 map[int]bool, arg9 chan int, arg10 []int) (int, bool, bool, float64, string, interface{}, map[int]bool, map[int]bool, chan int, []int) { + return arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 +} diff --git a/pkg/rules/test/net_http_hook.go b/pkg/rules/test/net_http_hook.go index e84bd037..d6a878b4 100644 --- a/pkg/rules/test/net_http_hook.go +++ b/pkg/rules/test/net_http_hook.go @@ -8,40 +8,40 @@ import ( "net/http" ) -func onEnterClientDo(call *http.CallContext, recv *http.Client, req *http.Request) { - println("Client.Do()") +func onEnterClientDo(call http.CallContext, recv *http.Client, req *http.Request) { + println("Before Client.Do()") } -func onExitClientDo(call *http.CallContext, resp *http.Response, err error) { +func onExitClientDo(call http.CallContext, resp *http.Response, err error) { panic("deliberately") } // arg type has package prefix -func onEnterNewRequestWithContext(call *http.CallContext, ctx context.Context, method, url string, body io.Reader) { +func onEnterNewRequestWithContext(call http.CallContext, ctx context.Context, method, url string, body io.Reader) { println("NewRequestWithContext()") } // many args have one type -func onEnterNewRequest(call *http.CallContext, method, url string, body io.Reader) { +func onEnterNewRequest(call http.CallContext, method, url string, body io.Reader) { println("NewRequest()") } // many args have interface type -func onEnterNewRequest1(call *http.CallContext, a, b interface{}, c interface{}) { +func onEnterNewRequest1(call http.CallContext, a, b interface{}, c interface{}) { println("NewRequest1()") } // only recv arg -func onEnterMaxBytesError(call *http.CallContext, recv *http.MaxBytesError) { +func onEnterMaxBytesError(call http.CallContext, recv *http.MaxBytesError) { println("MaxBytesError()") recv.Limit = 4008208820 } -func onExitMaxBytesError(call *http.CallContext, ret string) { - *(call.ReturnVals[0].(*string)) = "Prince of Qin Smashing the Battle line" +func onExitMaxBytesError(call http.CallContext, ret string) { + call.SetReturnVal(0, "Prince of Qin Smashing the Battle line") } // use field added by struct rule -func onExitNewRequest(call *http.CallContext, req *http.Request, _ interface{}) { +func onExitNewRequest(call http.CallContext, req *http.Request, _ interface{}) { println(req.Should) } diff --git a/pkg/rules/test/rule.go b/pkg/rules/test/rule.go index 24535460..42e5488e 100644 --- a/pkg/rules/test/rule.go +++ b/pkg/rules/test/rule.go @@ -179,4 +179,8 @@ func init() { api.NewRule("errors", "TestSkip2", "", "onEnterTestSkip2", "onExitTestSkip2"). WithRuleName("testrule"). Register() + + api.NewRule("errors", "TestGetSet", "", "onEnterTestGetSet", "onExitTestGetSet"). + WithRuleName("testrule"). + Register() } diff --git a/test/errors-test/main.go b/test/errors-test/main.go index bb8727bd..18a3f55a 100644 --- a/test/errors-test/main.go +++ b/test/errors-test/main.go @@ -15,4 +15,7 @@ func main() { val := errors.TestSkip2() fmt.Printf("val%v\n", val) + + arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 := errors.TestGetSet(1, true, false, 3.14, "str", nil, map[int]bool{1: true}, map[int]bool{2: true}, make(chan int), []int{1, 2, 3}) + fmt.Printf("val%v %v %v %v %v %v %v %v %v %v\n", arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10) } diff --git a/test/errors_test.go b/test/errors_test.go index 60b0619f..9eb7295a 100644 --- a/test/errors_test.go +++ b/test/errors_test.go @@ -17,6 +17,7 @@ func TestRunErrors(t *testing.T) { ExpectContains(t, stdout, "ptr") ExpectNotContains(t, stdout, "val1024") ExpectContains(t, stdout, "val1298") // 0x512 + ExpectContains(t, stdout, "7632") text := ReadInstrumentLog(t, "debug_fn_otel_inst_file_p4.go") re := regexp.MustCompile(".*OtelOnEnterTrampoline_TestSkip.*") diff --git a/tool/instrument/inst_func.go b/tool/instrument/inst_func.go index fa4420d1..d5377694 100644 --- a/tool/instrument/inst_func.go +++ b/tool/instrument/inst_func.go @@ -149,7 +149,7 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl clone := make([]dst.Expr, len(retVals)+1) clone[0] = shared.Ident(TrampolineCallContextName + varSuffix) for i := 1; i < len(clone); i++ { - clone[i] = shared.AddressOf(dst.Clone(retVals[i-1]).(dst.Expr)) + clone[i] = shared.AddressOf(retVals[i-1]) } return clone }()) @@ -168,7 +168,7 @@ func (rp *RuleProcessor) insertTJump(t *api.InstFuncRule, funcDecl *dst.FuncDecl shared.ReturnStmt(retVals), ), Else: shared.Block( - shared.DeferStmt(dst.Clone(onExitCall).(*dst.CallExpr)), + shared.DeferStmt(onExitCall), ), } // Add this trampoline-jump-if as optimization candidates diff --git a/tool/instrument/instrument.go b/tool/instrument/instrument.go index 1ed02c48..0fe30fc0 100644 --- a/tool/instrument/instrument.go +++ b/tool/instrument/instrument.go @@ -34,6 +34,8 @@ type RuleProcessor struct { varDecls []dst.Decl relocated map[string]string trampolineJumps []*TJump // Optimization candidates + callCtxDecl *dst.GenDecl + callCtxMethods []*dst.FuncDecl } func newRuleProcessor(args []string, pkgName string) *RuleProcessor { diff --git a/tool/instrument/optimize.go b/tool/instrument/optimize.go index a967de0b..353cfdd6 100644 --- a/tool/instrument/optimize.go +++ b/tool/instrument/optimize.go @@ -132,10 +132,11 @@ func replenishCallContextLiteral(tjump *TJump, expr dst.Expr) { paramLiteral.Elts = names } -func newCallContext(tjump *TJump) (dst.Expr, error) { +func (rp *RuleProcessor) newCallContextImpl(tjump *TJump) (dst.Expr, error) { // One line please, otherwise debugging line number will be a nightmare - const newCallContext = `&CallContext{Params:[]interface{}{},ReturnVals:[]interface{}{},SkipCall:false,}` - astRoot, err := shared.ParseAstFromSnippet(newCallContext) + tmpl := fmt.Sprintf("&CallContextImpl%s{Params:[]interface{}{},ReturnVals:[]interface{}{}}", + rp.rule2Suffix[tjump.rule]) + astRoot, err := shared.ParseAstFromSnippet(tmpl) if err != nil { return nil, fmt.Errorf("failed to parse new CallContext: %w", err) } @@ -147,7 +148,7 @@ func newCallContext(tjump *TJump) (dst.Expr, error) { func (rp *RuleProcessor) removeOnEnterTrampolineCall(tjump *TJump) error { // Construct CallContext on the fly and pass to onExit trampoline defer call - callContextExpr, err := newCallContext(tjump) + callContextExpr, err := rp.newCallContextImpl(tjump) if err != nil { return fmt.Errorf("failed to construct CallContext: %w", err) } @@ -218,6 +219,7 @@ func (rp *RuleProcessor) optimizeTJumps() (err error) { // because there might be more than one trampoline-jump-if in the same // function, they are nested in the else block. See findJumpPoint for // more details. + // TODO: Remove corresponding CallContextImpl methods rule := tjump.rule removedOnExit := false if rule.OnExit == "" { diff --git a/tool/instrument/template.go b/tool/instrument/template.go index 5f6bd75f..a47727cf 100644 --- a/tool/instrument/template.go +++ b/tool/instrument/template.go @@ -2,22 +2,46 @@ package instrument -// @@ Modification on this trampoline template should be cautious, as it imposes -// many implicit constraints on generated code, known constraints are as follows: -// - It's performance critical, so it should be as simple as possible -// - It should not import any package because there is no guarantee that package -// is existed in import config during the compilation, one practical approach -// is to use function variables and setup these variables in preprocess stage -// - It should not panic as this affects user application -// - Function and variable names are coupled with the framework, any modification -// on them should be synced with the framework +// Seeing is not always believing. The following template is a bit tricky, see +// trampoline.go for more details -// Variable Declaration +// Struct Template +type CallContextImpl struct { + Params []interface{} + ReturnVals []interface{} + SkipCall bool + Data interface{} +} + +func (c *CallContextImpl) SetSkipCall(skip bool) { c.SkipCall = skip } +func (c *CallContextImpl) IsSkipCall() bool { return c.SkipCall } +func (c *CallContextImpl) SetData(data interface{}) { c.Data = data } +func (c *CallContextImpl) GetData() interface{} { return c.Data } +func (c *CallContextImpl) GetParam(idx int) interface{} { + switch idx { + } + return nil +} +func (c *CallContextImpl) SetParam(idx int, val interface{}) { + switch idx { + } +} +func (c *CallContextImpl) GetReturnVal(idx int) interface{} { + switch idx { + } + return nil +} +func (c *CallContextImpl) SetReturnVal(idx int, val interface{}) { + switch idx { + } +} + +// Variable Template var OtelGetStackImpl func() []byte = nil var OtelPrintStackImpl func([]byte) = nil -// Function Declaration -func OtelOnEnterTrampoline() (*CallContext, bool) { +// Trampoline Template +func OtelOnEnterTrampoline() (CallContext, bool) { defer func() { if err := recover(); err != nil { println("failed to exec onEnter hook", "OtelOnEnterNamePlaceholder") @@ -30,16 +54,12 @@ func OtelOnEnterTrampoline() (*CallContext, bool) { } } }() - callContext := &CallContext{ - Params: nil, - ReturnVals: nil, - SkipCall: false, - } + callContext := &CallContextImpl{} callContext.Params = []interface{}{} return callContext, callContext.SkipCall } -func OtelOnExitTrampoline(callContext *CallContext) { +func OtelOnExitTrampoline(callContext CallContext) { defer func() { if err := recover(); err != nil { println("failed to exec onExit hook", "OtelOnExitNamePlaceholder") @@ -52,5 +72,5 @@ func OtelOnExitTrampoline(callContext *CallContext) { } } }() - callContext.ReturnVals = []interface{}{} + callContext.(*CallContextImpl).ReturnVals = []interface{}{} } diff --git a/tool/instrument/trampoline.go b/tool/instrument/trampoline.go index 73596890..e6a6b982 100644 --- a/tool/instrument/trampoline.go +++ b/tool/instrument/trampoline.go @@ -7,14 +7,10 @@ import ( "go/token" "path/filepath" - "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/util" - + "github.com/alibaba/opentelemetry-go-auto-instrumentation/api" "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/resource" - "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/shared" - - "github.com/alibaba/opentelemetry-go-auto-instrumentation/api" - + "github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/util" "github.com/dave/dst" ) @@ -31,15 +27,34 @@ import ( // its name suggests, it jumps to the trampoline function from raw function. const ( + TrampolineSetParamName = "SetParam" + TrampolineGetParamName = "GetParam" + TrampolineSetReturnValName = "SetReturnVal" + TrampolineGetReturnValName = "GetReturnVal" + TrampolineValIdentifier = "val" + TrampolineCtxIdentifier = "c" + TrampolineParamsIdentifier = "Params" + TrampolineReturnValsIdentifier = "ReturnVals" TrampolineSkipName = "skip" TrampolineCallContextName = "callContext" TrampolineCallContextType = "CallContext" + TrampolineCallContextImplType = "CallContextImpl" TrampolineOnEnterName = "OtelOnEnterTrampoline" TrampolineOnExitName = "OtelOnExitTrampoline" TrampolineOnEnterNamePlaceholder = "\"OtelOnEnterNamePlaceholder\"" TrampolineOnExitNamePlaceholder = "\"OtelOnExitNamePlaceholder\"" ) +// @@ Modification on this trampoline template should be cautious, as it imposes +// many implicit constraints on generated code, known constraints are as follows: +// - It's performance critical, so it should be as simple as possible +// - It should not import any package because there is no guarantee that package +// is existed in import config during the compilation, one practical approach +// is to use function variables and setup these variables in preprocess stage +// - It should not panic as this affects user application +// - Function and variable names are coupled with the framework, any modification +// on them should be synced with the framework + //go:embed template.go var trampolineTemplate string @@ -50,32 +65,41 @@ func (rp *RuleProcessor) materializeTemplate() error { if err != nil { return fmt.Errorf("failed to parse trampoline template: %w", err) } - varDecls := make([]dst.Decl, 0) - var onEnterDecl, onExitDecl *dst.FuncDecl + + rp.varDecls = make([]dst.Decl, 0) + rp.callCtxMethods = make([]*dst.FuncDecl, 0) for _, node := range astRoot.Decls { // Materialize function declarations if decl, ok := node.(*dst.FuncDecl); ok { if decl.Name.Name == TrampolineOnEnterName { - onEnterDecl = decl + rp.onEnterHookFunc = decl + rp.addDecl(decl) } else if decl.Name.Name == TrampolineOnExitName { - onExitDecl = decl + rp.onExitHookFunc = decl + rp.addDecl(decl) + } else if decl.Recv != nil { + // We know exactly this is CallContextImpl method + t := decl.Recv.List[0].Type.(*dst.StarExpr).X.(*dst.Ident).Name + util.Assert(t == TrampolineCallContextImplType, "sanity check") + rp.callCtxMethods = append(rp.callCtxMethods, decl) + rp.addDecl(decl) } } // Materialize variable declarations if decl, ok := node.(*dst.GenDecl); ok { // No further processing for variable declarations, just append them if decl.Tok == token.VAR { - varDecls = append(varDecls, decl) + rp.varDecls = append(rp.varDecls, decl) + } else if decl.Tok == token.TYPE { + rp.callCtxDecl = decl + rp.addDecl(decl) } } } - util.Assert(len(varDecls) > 0, "sanity check") - util.Assert(onEnterDecl != nil && onExitDecl != nil, "sanity check") - rp.onEnterHookFunc = onEnterDecl - rp.onExitHookFunc = onExitDecl - rp.addDecl(onEnterDecl) - rp.addDecl(onExitDecl) - rp.varDecls = varDecls + util.Assert(rp.callCtxDecl != nil, "sanity check") + util.Assert(len(rp.varDecls) > 0, "sanity check") + util.Assert(rp.onEnterHookFunc != nil, "sanity check") + util.Assert(rp.onExitHookFunc != nil, "sanity check") return nil } @@ -226,7 +250,7 @@ func (rp *RuleProcessor) addOnEnterHookVarDecl(t *api.InstFuncRule, traits []Par // raw function is not exposed, we need to use interface{} to represent them. err := rectifyAnyType(paramTypes, traits) if err != nil { - return fmt.Errorf("failed to rectify any type: %w", err) + return fmt.Errorf("failed to rectify any type on enter: %w", err) } // Generate onEnter var decl @@ -240,7 +264,7 @@ func (rp *RuleProcessor) addOnExitVarHookDecl(t *api.InstFuncRule, traits []Para addCallContext(paramTypes) err := rectifyAnyType(paramTypes, traits) if err != nil { - return fmt.Errorf("failed to rectify any type: %w", err) + return fmt.Errorf("failed to rectify any type on exit: %w", err) } // Generate onExit var decl @@ -282,7 +306,7 @@ func (rp *RuleProcessor) renameFunc(t *api.InstFuncRule) { func addCallContext(list *dst.FieldList) { callCtx := shared.NewField( TrampolineCallContextName, - shared.DereferenceOf(dst.NewIdent(TrampolineCallContextType)), + dst.NewIdent(TrampolineCallContextType), ) list.List = append([]*dst.Field{callCtx}, list.List...) } @@ -320,13 +344,8 @@ func (rp *RuleProcessor) rectifyTypes() { for _, list := range candidate { for i := 0; i < len(list.List); i++ { paramField := list.List[i] - if ft, ok := paramField.Type.(*dst.Ellipsis); ok { - // If parameter is type of ...T, we need to convert it to *[]T - paramField.Type = shared.DereferenceOf(shared.ArrayType(ft.Elt)) - } else { - // Otherwise, convert it to *T as usual - paramField.Type = shared.DereferenceOf(paramField.Type) - } + paramFieldType := desugarType(paramField) + paramField.Type = shared.DereferenceOf(paramFieldType) } } addCallContext(onExitHookFunc.Type.Params) @@ -358,21 +377,175 @@ func (rp *RuleProcessor) replenishCallContext(onEnter bool) bool { } } } + util.ShouldNotReachHereT("failed to replenish call context") return false } +// implementCallContext effectively "implements" the CallContext interface by +// renaming occurrences of CallContextImpl to CallContextImpl{suffix} in the +// trampoline template +func (rp *RuleProcessor) implementCallContext(t *api.InstFuncRule) { + suffix := rp.rule2Suffix[t] + structType := rp.callCtxDecl.Specs[0].(*dst.TypeSpec) + util.Assert(structType.Name.Name == TrampolineCallContextImplType, + "sanity check") + structType.Name.Name += suffix // type declaration + for _, method := range rp.callCtxMethods { // method declaration + method.Recv.List[0].Type.(*dst.StarExpr).X.(*dst.Ident).Name += suffix + } + for _, node := range []dst.Node{rp.onEnterHookFunc, rp.onExitHookFunc} { + dst.Inspect(node, func(node dst.Node) bool { + if ident, ok := node.(*dst.Ident); ok { + if ident.Name == TrampolineCallContextImplType { + ident.Name += suffix + return false + } + } + return true + }) + } +} + +func setValue(field string, idx int, typ dst.Expr) *dst.CaseClause { + // *(c.Params[idx].(*int)) = val.(int) + // c.Params[idx] = val iff type is interface{} + se := shared.SelectorExpr(shared.Ident(TrampolineCtxIdentifier), field) + ie := shared.IndexExpr(se, shared.IntLit(idx)) + te := shared.TypeAssertExpr(ie, shared.DereferenceOf(typ)) + pe := shared.ParenExpr(te) + de := shared.DereferenceOf(pe) + val := shared.Ident(TrampolineValIdentifier) + assign := shared.AssignStmt(de, shared.TypeAssertExpr(val, typ)) + if _, ok := typ.(*dst.InterfaceType); ok { + assign = shared.AssignStmt(ie, val) + } + caseClause := &dst.CaseClause{ + List: shared.Exprs(shared.IntLit(idx)), + Body: shared.Stmts(assign), + } + return caseClause +} + +func getValue(field string, idx int, typ dst.Expr) *dst.CaseClause { + // return *(c.Params[idx].(*int)) + // return c.Params[idx] iff type is interface{} + se := shared.SelectorExpr(shared.Ident(TrampolineCtxIdentifier), field) + ie := shared.IndexExpr(se, shared.IntLit(idx)) + te := shared.TypeAssertExpr(ie, shared.DereferenceOf(typ)) + pe := shared.ParenExpr(te) + de := shared.DereferenceOf(pe) + ret := shared.ReturnStmt(shared.Exprs(de)) + if _, ok := typ.(*dst.InterfaceType); ok { + ret = shared.ReturnStmt(shared.Exprs(ie)) + } + caseClause := &dst.CaseClause{ + List: shared.Exprs(shared.IntLit(idx)), + Body: shared.Stmts(ret), + } + return caseClause +} + +func getParamClause(idx int, typ dst.Expr) *dst.CaseClause { + return getValue(TrampolineParamsIdentifier, idx, typ) +} + +func setParamClause(idx int, typ dst.Expr) *dst.CaseClause { + return setValue(TrampolineParamsIdentifier, idx, typ) +} + +func getReturnValClause(idx int, typ dst.Expr) *dst.CaseClause { + return getValue(TrampolineReturnValsIdentifier, idx, typ) +} + +func setReturnValClause(idx int, typ dst.Expr) *dst.CaseClause { + return setValue(TrampolineReturnValsIdentifier, idx, typ) +} + +// desugarType desugars parameter type to its original type, if parameter +// is type of ...T, it will be converted to []T +func desugarType(param *dst.Field) dst.Expr { + if ft, ok := param.Type.(*dst.Ellipsis); ok { + return shared.ArrayType(ft.Elt) + } + return param.Type +} + +func (rp *RuleProcessor) rewriteCallContextImpl() { + util.Assert(len(rp.callCtxMethods) > 4, "sanity check") + var ( + methodSetParam *dst.FuncDecl + methodGetParam *dst.FuncDecl + methodGetRetVal *dst.FuncDecl + methodSetRetVal *dst.FuncDecl + ) + for _, decl := range rp.callCtxMethods { + switch decl.Name.Name { + case TrampolineSetParamName: + methodSetParam = decl + case TrampolineGetParamName: + methodGetParam = decl + case TrampolineGetReturnValName: + methodGetRetVal = decl + case TrampolineSetReturnValName: + methodSetRetVal = decl + } + } + // Rewrite SetParam and GetParam methods + // Dont believe what you see in template.go, we will null out it and rewrite + // the whole switch statement + methodSetParamBody := methodSetParam.Body.List[0].(*dst.SwitchStmt).Body + methodGetParamBody := methodGetParam.Body.List[0].(*dst.SwitchStmt).Body + methodGetRetValBody := methodGetRetVal.Body.List[0].(*dst.SwitchStmt).Body + methodSetRetValBody := methodSetRetVal.Body.List[0].(*dst.SwitchStmt).Body + methodGetParamBody.List = nil + methodSetParamBody.List = nil + methodGetRetValBody.List = nil + methodSetRetValBody.List = nil + idx := 0 + for _, param := range rp.rawFunc.Type.Params.List { + paramType := desugarType(param) + for range param.Names { + clause := setParamClause(idx, paramType) + methodSetParamBody.List = append(methodSetParamBody.List, clause) + clause = getParamClause(idx, paramType) + methodGetParamBody.List = append(methodGetParamBody.List, clause) + idx++ + } + } + // Rewrite GetReturnVal and SetReturnVal methods + if rp.rawFunc.Type.Results != nil { + idx = 0 + for _, retval := range rp.rawFunc.Type.Results.List { + retType := desugarType(retval) + for range retval.Names { + clause := getReturnValClause(idx, retType) + methodGetRetValBody.List = append(methodGetRetValBody.List, clause) + clause = setReturnValClause(idx, retType) + methodSetRetValBody.List = append(methodSetRetValBody.List, clause) + idx++ + } + } + } +} + func (rp *RuleProcessor) generateTrampoline(t *api.InstFuncRule, funcDecl *dst.FuncDecl) error { rp.rawFunc = funcDecl - // Materialize trampoline template + // Materialize various declarations from template file, no one wants to see + // a bunch of manual AST code generation, isn't it? err := rp.materializeTemplate() if err != nil { return fmt.Errorf("failed to materialize template: %w", err) } - // Rename onEnter and onExit trampoline function names + // Implement CallContext interface + rp.implementCallContext(t) + // Rewrite type-aware CallContext APIs + rp.rewriteCallContextImpl() + + // Rename trampoline functions rp.renameFunc(t) - // Rectify types of onEnter and onExit trampoline funcs + // Rectify types of trampoline functions rp.rectifyTypes() - // Generate calls to onEnter and onExit hooks + // Generate calls to hook functions within trampoline functions if t.OnEnter != "" { traits, err := getHookParamTraits(t, true) if err != nil { diff --git a/tool/preprocess/dependency.go b/tool/preprocess/dependency.go index 80b4aac9..eea5b6ba 100644 --- a/tool/preprocess/dependency.go +++ b/tool/preprocess/dependency.go @@ -77,6 +77,9 @@ func (dp *DepProcessor) postProcess() { // rm -rf otel_rules _ = os.RemoveAll(OtelRules) + // rm -rf otel_pkgdep + _ = os.RemoveAll(OtelPkgDepsDir) + // Restore everything we have modified during instrumentation err := dp.restoreBackupFiles() if err != nil { diff --git a/tool/shared/ast.go b/tool/shared/ast.go index fdd4e9a6..a05e2298 100644 --- a/tool/shared/ast.go +++ b/tool/shared/ast.go @@ -14,7 +14,7 @@ import ( // AST Construction func AddressOf(expr dst.Expr) *dst.UnaryExpr { - return &dst.UnaryExpr{Op: token.AND, X: expr} + return &dst.UnaryExpr{Op: token.AND, X: dst.Clone(expr).(dst.Expr)} } func CallTo(name string, args []dst.Expr) *dst.CallExpr { @@ -39,6 +39,20 @@ func Ident(name string) *dst.Ident { } } +func StringLit(value string) *dst.BasicLit { + return &dst.BasicLit{ + Kind: token.STRING, + Value: fmt.Sprintf("%q", value), + } +} + +func IntLit(value int) *dst.BasicLit { + return &dst.BasicLit{ + Kind: token.INT, + Value: fmt.Sprintf("%d", value), + } +} + func Block(stmt dst.Stmt) *dst.BlockStmt { return &dst.BlockStmt{ List: []dst.Stmt{ @@ -57,6 +71,37 @@ func Exprs(exprs ...dst.Expr) []dst.Expr { return exprs } +func Stmts(stmts ...dst.Stmt) []dst.Stmt { + return stmts +} + +func SelectorExpr(x dst.Expr, sel string) *dst.SelectorExpr { + return &dst.SelectorExpr{ + X: dst.Clone(x).(dst.Expr), + Sel: Ident(sel), + } +} + +func IndexExpr(x dst.Expr, index dst.Expr) *dst.IndexExpr { + return &dst.IndexExpr{ + X: dst.Clone(x).(dst.Expr), + Index: dst.Clone(index).(dst.Expr), + } +} + +func TypeAssertExpr(x dst.Expr, typ dst.Expr) *dst.TypeAssertExpr { + return &dst.TypeAssertExpr{ + X: x, + Type: dst.Clone(typ).(dst.Expr), + } +} + +func ParenExpr(x dst.Expr) *dst.ParenExpr { + return &dst.ParenExpr{ + X: dst.Clone(x).(dst.Expr), + } +} + func NewField(name string, typ dst.Expr) *dst.Field { newField := &dst.Field{ Names: []*dst.Ident{dst.NewIdent(name)}, @@ -90,13 +135,21 @@ func ExprStmt(expr dst.Expr) *dst.ExprStmt { } func DeferStmt(call *dst.CallExpr) *dst.DeferStmt { - return &dst.DeferStmt{Call: call} + return &dst.DeferStmt{Call: dst.Clone(call).(*dst.CallExpr)} } func ReturnStmt(results []dst.Expr) *dst.ReturnStmt { return &dst.ReturnStmt{Results: results} } +func AssignStmt(lhs, rhs dst.Expr) *dst.AssignStmt { + return &dst.AssignStmt{ + Lhs: []dst.Expr{lhs}, + Tok: token.ASSIGN, + Rhs: []dst.Expr{rhs}, + } +} + func AddStructField(decl dst.Decl, name string, typ string) { gen, ok := decl.(*dst.GenDecl) if !ok {