From 2c003e4f895ef51b9c22e7c5bb5018dd6007ff96 Mon Sep 17 00:00:00 2001 From: xhd2015 Date: Sat, 6 Apr 2024 14:20:22 +0800 Subject: [PATCH] support mocking variables and constants --- README.md | 12 +- README_zh_cn.md | 12 +- cmd/xgo/exec_tool/debug.go | 1 + cmd/xgo/exec_tool/env.go | 2 + cmd/xgo/main.go | 11 +- cmd/xgo/patch/runtime_def.go | 15 +- cmd/xgo/patch/runtime_def_gen.go | 2 + cmd/xgo/version.go | 4 +- patch/adapter_go1.17_18.go | 3 - patch/adapter_go1.17_18_19.go | 13 - patch/adapter_go1.19.go | 3 - patch/ctxt/ctx.go | 18 +- patch/ctxt/skip_pkg_go1.19_and_below.go | 15 + patch/ctxt/skip_pkg_go1.20.go | 8 + patch/link_name.go | 13 +- patch/syntax/func_stub_def.go | 4 +- patch/syntax/helper_code.go | 9 +- patch/syntax/helper_code_gen.go | 13 +- patch/syntax/rewrite.go | 18 + patch/syntax/skip_go1.17_18_19.go | 17 - patch/syntax/skip_go1.20_and_above.go | 8 - patch/syntax/syntax.go | 214 +++++--- patch/syntax/vars.go | 532 +++++++++++++++++++ patch/trap.go | 12 +- patch/trap_runtime/xgo_trap.go | 18 + runtime/core/func.go | 32 +- runtime/core/version.go | 4 +- runtime/functab/functab.go | 32 +- runtime/mock/mock.go | 8 +- runtime/mock/patch.go | 22 +- runtime/test/debug/debug_test.go | 19 +- runtime/test/func_list/func_list_test.go | 6 +- runtime/test/func_list/func_list_var_test.go | 46 ++ runtime/test/func_list/sub/sub.go | 5 + runtime/test/mock_var/mock_var_test.go | 38 ++ runtime/test/mock_var/sub/sub.go | 3 + runtime/test/patch/patch_var_test.go | 19 + runtime/trap/inspect.go | 6 + runtime/trap/trap.go | 69 ++- script/build-compiler/main.go | 1 + script/run-test/main.go | 5 +- script/setup-dev/main.go | 1 + support/goinfo/mod.go | 120 +++++ support/goinfo/mod_test.go | 29 + support/osinfo/osinfo_nonwin.go | 2 + support/osinfo/osinfo_win.go | 2 + 46 files changed, 1245 insertions(+), 201 deletions(-) create mode 100644 patch/ctxt/skip_pkg_go1.19_and_below.go create mode 100644 patch/ctxt/skip_pkg_go1.20.go delete mode 100644 patch/syntax/skip_go1.17_18_19.go delete mode 100644 patch/syntax/skip_go1.20_and_above.go create mode 100644 patch/syntax/vars.go create mode 100644 runtime/test/func_list/func_list_var_test.go create mode 100644 runtime/test/func_list/sub/sub.go create mode 100644 runtime/test/mock_var/mock_var_test.go create mode 100644 runtime/test/mock_var/sub/sub.go create mode 100644 runtime/test/patch/patch_var_test.go create mode 100644 support/goinfo/mod.go create mode 100644 support/goinfo/mod_test.go diff --git a/README.md b/README.md index 63e34bbc..7960039b 100644 --- a/README.md +++ b/README.md @@ -41,12 +41,12 @@ There are other options,see [doc/INSTALLATION.md](./doc/INSTALLATION.md). There is no specific limitation on OS and Architecture. **All OS and Architectures** are supported by `xgo` as long as they are supported by `go`. -| | x86_64 | ARM64 | Any Other Arch... | -|---------|-----------|-----------|-----------| -| Linux | Y | Y | Y| -| Windows | Y | Y | Y| -| macOS | Y | Y | Y| -| Any Other OS... | Y | Y | Y| +| | x86 | x86_64 (amd64) | arm64 | any other Arch... | +|:---------|:-----------:|:-----------:|:-----------:|:-----------:| +| Linux | Y | Y | Y | Y | +| Windows | Y | Y | Y | Y | +| macOS | Y | Y | Y | Y | +| any other OS... | Y | Y | Y | Y| # Quick Start Let's write a unit test with `xgo`: diff --git a/README_zh_cn.md b/README_zh_cn.md index e71d05f4..c2ce53a9 100644 --- a/README_zh_cn.md +++ b/README_zh_cn.md @@ -39,12 +39,12 @@ xgo version 对OS和Arch没有限制, `xgo`支持所有`go`支持的OS和Arch。 -| | x86_64 | ARM64 | 任何其他架构... | -|---------|-----------|-----------|-----------| -| Linux | Y | Y | Y| -| Windows | Y | Y | Y| -| macOS | Y | Y | Y| -| 任何其他OS... | Y | Y | Y| +| | x86 | x86_64 (amd64) | arm64 | 任何其他架构... | +|:---------|:-----------:|:-----------:|:-----------:|:-----------:| +| Linux | Y | Y | Y | Y | +| Windows | Y | Y | Y | Y | +| macOS | Y | Y | Y | Y | +| 任何其他OS... | Y | Y | Y | Y| # 快速开始 我们基于`xgo`编写一个单元测试: diff --git a/cmd/xgo/exec_tool/debug.go b/cmd/xgo/exec_tool/debug.go index a189e345..6e0cd9bf 100644 --- a/cmd/xgo/exec_tool/debug.go +++ b/cmd/xgo/exec_tool/debug.go @@ -18,6 +18,7 @@ func getDebugEnv(xgoCompilerEnableEnv string) map[string]string { XGO_DEBUG_DUMP_AST: os.Getenv(XGO_DEBUG_DUMP_AST), XGO_DEBUG_DUMP_AST_FILE: os.Getenv(XGO_DEBUG_DUMP_AST_FILE), "GOCACHE": os.Getenv("GOCACHE"), + XGO_MAIN_MODULE: os.Getenv(XGO_MAIN_MODULE), "GOROOT": "../..", "PATH": "../../bin:${env:PATH}", "XGO_COMPILER_ENABLE": xgoCompilerEnableEnv, diff --git a/cmd/xgo/exec_tool/env.go b/cmd/xgo/exec_tool/env.go index 0e91c4d8..aac3a7f9 100644 --- a/cmd/xgo/exec_tool/env.go +++ b/cmd/xgo/exec_tool/env.go @@ -11,3 +11,5 @@ const XGO_DEBUG_VSCODE = "XGO_DEBUG_VSCODE" const XGO_TOOLCHAIN_VERSION = "XGO_TOOLCHAIN_VERSION" const XGO_TOOLCHAIN_REVISION = "XGO_TOOLCHAIN_REVISION" const XGO_TOOLCHAIN_VERSION_NUMBER = "XGO_TOOLCHAIN_VERSION_NUMBER" + +const XGO_MAIN_MODULE = "XGO_MAIN_MODULE" diff --git a/cmd/xgo/main.go b/cmd/xgo/main.go index bd19300d..641f5a50 100644 --- a/cmd/xgo/main.go +++ b/cmd/xgo/main.go @@ -353,8 +353,17 @@ func handleBuild(cmd string, args []string) error { go tailLog(compileLog) } + execCmdEnv := os.Environ() var execCmd *exec.Cmd if !cmdExec { + mainModule, err := goinfo.ResolveMainModule(projectDir, remainArgs) + if err != nil { + if !errors.Is(err, goinfo.ErrGoModNotFound) && !errors.Is(err, goinfo.ErrGoModDoesNotHaveModule) { + return err + } + } + logDebug("resolved main module: %s", mainModule) + execCmdEnv = append(execCmdEnv, "XGO_MAIN_MODULE="+mainModule) // GOCACHE="$shdir/build-cache" PATH=$goroot/bin:$PATH GOROOT=$goroot DEBUG_PKG=$debug go build -toolexec="$shdir/exce_tool $cmd" "${build_flags[@]}" "$@" buildCmdArgs := []string{cmd} if toolExecFlag != "" { @@ -401,7 +410,7 @@ func handleBuild(cmd string, args []string) error { logDebug("command: %v", remainArgs) execCmd = exec.Command(remainArgs[0], remainArgs[1:]...) } - execCmd.Env = os.Environ() + execCmd.Env = execCmdEnv execCmd.Env, err = patchEnvWithGoroot(execCmd.Env, instrumentGoroot) if err != nil { return err diff --git a/cmd/xgo/patch/runtime_def.go b/cmd/xgo/patch/runtime_def.go index 9d9f2029..180bb410 100644 --- a/cmd/xgo/patch/runtime_def.go +++ b/cmd/xgo/patch/runtime_def.go @@ -79,7 +79,7 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { for _, n := range noders { files = append(files, n.file) } - xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) { + xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) *syntax.File { p := &noder{ err: make(chan syntax.Error), } @@ -88,10 +88,11 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { if err != nil { e := err.(syntax.Error) p.error(e) - return + return nil } p.file = file noders = append(noders, p) + return file }) } ` @@ -102,17 +103,18 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { for _, n := range noders { files = append(files, n.file) } - xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) { + xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) *syntax.File { p := &noder{} fbase := syntax.NewFileBase(name) file, err := syntax.Parse(fbase, r, nil, p.pragma, syntax.CheckBranches) if err != nil { e := err.(syntax.Error) base.ErrorfAt(p.makeXPos(e.Pos), "%s", e.Msg) - return + return nil } p.file = file noders = append(noders, p) + return file }) } ` @@ -123,17 +125,18 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { for _, n := range noders { files = append(files, n.file) } - xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) { + xgo_syntax.AfterFilesParsed(files, func(name string, r io.Reader) *syntax.File { p := &noder{} fbase := syntax.NewFileBase(name) file, err := syntax.Parse(fbase, r, nil, p.pragma, syntax.CheckBranches) if err != nil { e := err.(syntax.Error) base.ErrorfAt(m.makeXPos(e.Pos), 0,"%s", e.Msg) - return + return nil } p.file = file noders = append(noders, p) + return file }) } ` diff --git a/cmd/xgo/patch/runtime_def_gen.go b/cmd/xgo/patch/runtime_def_gen.go index 76b07959..d11046d3 100644 --- a/cmd/xgo/patch/runtime_def_gen.go +++ b/cmd/xgo/patch/runtime_def_gen.go @@ -8,7 +8,9 @@ const RuntimeExtraDef = ` func __xgo_getcurg() unsafe.Pointer func __xgo_trap(pkgPath string, identityName string, generic bool, recv interface{}, args []interface{}, results []interface{}) (func(), bool) func __xgo_trap_for_generated(pkgPath string, pc uintptr, identityName string, generic bool, recv interface{}, args []interface{}, results []interface{}) (func(), bool) +func __xgo_trap_var_for_generated(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) func __xgo_set_trap(trap func(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool)) +func __xgo_set_trap_var(trap func(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool)) func __xgo_register_func(info interface{}) func __xgo_retrieve_all_funcs_and_clear(f func(info interface{})) func __xgo_init_finished() bool diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index cb56d9fb..5e124344 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.18" -const REVISION = "03d82b3e31832e5947c5d3a7ef8752f4f39db28c+1" -const NUMBER = 162 +const REVISION = "1211c519c8005ddbd66189cf64e958aa69e5789f+1" +const NUMBER = 163 func getRevision() string { revSuffix := "" diff --git a/patch/adapter_go1.17_18.go b/patch/adapter_go1.17_18.go index f57747c2..0d46f38e 100644 --- a/patch/adapter_go1.17_18.go +++ b/patch/adapter_go1.17_18.go @@ -36,8 +36,5 @@ func wrapListType(expr *ir.CompLitExpr) *ir.CompLitExpr { } func canInsertTrap(fn *ir.Func) bool { - if isSkippableSpecialPkg() { - return false - } return true } diff --git a/patch/adapter_go1.17_18_19.go b/patch/adapter_go1.17_18_19.go index 24cf220d..88d6183c 100644 --- a/patch/adapter_go1.17_18_19.go +++ b/patch/adapter_go1.17_18_19.go @@ -2,16 +2,3 @@ // +build go1.17,!go1.20 package patch - -import ( - xgo_ctxt "cmd/compile/internal/xgo_rewrite_internal/patch/ctxt" - "strings" -) - -func isSkippableSpecialPkg() bool { - curPkgPath := xgo_ctxt.GetPkgPath() - if strings.HasPrefix(curPkgPath, "golang.org/x/") { - return true - } - return false -} diff --git a/patch/adapter_go1.19.go b/patch/adapter_go1.19.go index e5dcf0f3..ca6417fd 100644 --- a/patch/adapter_go1.19.go +++ b/patch/adapter_go1.19.go @@ -34,9 +34,6 @@ func canInsertTrap(fn *ir.Func) bool { if curPkgPath != fnPkgPath { return false } - if isSkippableSpecialPkg() { - return false - } // fnName := fn.Sym().Name // if strings.Contains(fnName, "[") && strings.Contains(fnName, "]") { // return false diff --git a/patch/ctxt/ctx.go b/patch/ctxt/ctx.go index c33bb27c..8532b89e 100644 --- a/patch/ctxt/ctx.go +++ b/patch/ctxt/ctx.go @@ -2,6 +2,7 @@ package ctxt import ( "cmd/compile/internal/base" + "os" "strings" ) @@ -9,7 +10,16 @@ const XgoModule = "github.com/xhd2015/xgo" const XgoRuntimePkg = XgoModule + "/runtime" const XgoRuntimeCorePkg = XgoModule + "/runtime/core" +var XgoMainModule = os.Getenv("XGO_MAIN_MODULE") + func SkipPackageTrap() bool { + pkgPath := GetPkgPath() + if pkgPath == "" { + return true + } + if strings.HasPrefix(pkgPath, "runtime/") || strings.HasPrefix(pkgPath, "internal/") { + return true + } if base.Flag.Std { // skip std lib, especially skip: // runtime, runtime/internal, runtime/*, reflect, unsafe, syscall, sync, sync/atomic, internal/* @@ -24,14 +34,15 @@ func SkipPackageTrap() bool { // func may be a foreigner. // allow http - pkgPath := GetPkgPath() if _, ok := stdWhitelist[pkgPath]; ok { return false } return true } + if isSkippableSpecialPkg(pkgPath) { + return true + } - pkgPath := GetPkgPath() if IsPkgXgoSkipTrap(pkgPath) { return true } @@ -111,6 +122,7 @@ func AllowPkgFuncTrap(pkgPath string, isStd bool, funcName string) bool { return true } +// skip all packages for xgo,except test func IsPkgXgoSkipTrap(pkg string) bool { suffix, ok := cutPkgPrefix(pkg, XgoModule) if !ok { @@ -119,7 +131,7 @@ func IsPkgXgoSkipTrap(pkg string) bool { if suffix == "" { return true } - // check if the package is test, runtime/test + // check if the package is test or runtime/test _, ok = cutPkgPrefix(suffix, "test") if ok { return false diff --git a/patch/ctxt/skip_pkg_go1.19_and_below.go b/patch/ctxt/skip_pkg_go1.19_and_below.go new file mode 100644 index 00000000..745122dd --- /dev/null +++ b/patch/ctxt/skip_pkg_go1.19_and_below.go @@ -0,0 +1,15 @@ +//go:build go1.17 && !go1.20 +// +build go1.17,!go1.20 + +package ctx + +import ( + "strings" +) + +func isSkippableSpecialPkg(pkgPath string) bool { + if strings.HasPrefix(pkgPath, "golang.org/x/") { + return true + } + return false +} diff --git a/patch/ctxt/skip_pkg_go1.20.go b/patch/ctxt/skip_pkg_go1.20.go new file mode 100644 index 00000000..1e122d44 --- /dev/null +++ b/patch/ctxt/skip_pkg_go1.20.go @@ -0,0 +1,8 @@ +//go:build go1.20 +// +build go1.20 + +package ctxt + +func isSkippableSpecialPkg(pkgPath string) bool { + return false +} diff --git a/patch/link_name.go b/patch/link_name.go index 8672af8b..4de3f275 100644 --- a/patch/link_name.go +++ b/patch/link_name.go @@ -19,16 +19,21 @@ const xgoRuntimeTrapPkg = xgoRuntimePkgPrefix + "trap" const xgoOnTestStart = "__xgo_on_test_start" const XgoLinkSetTrap = "__xgo_link_set_trap" +const XgoLinkSetTrapVar = "__xgo_link_set_trap_var" const XgoTrapForGenerated = "__xgo_trap_for_generated" const setTrap = "__xgo_set_trap" +const setTrapVar = "__xgo_set_trap_var" +const XgoTrapVarForGenerated = "__xgo_trap_var_for_generated" // only allowed from reflect const reflectSetImpl = "__xgo_set_all_method_by_name_impl" var linkMap = map[string]string{ "__xgo_link_getcurg": "__xgo_getcurg", - "__xgo_link_set_trap": setTrap, + XgoLinkSetTrap: setTrap, + XgoLinkSetTrapVar: setTrapVar, xgo_syntax.XgoLinkTrapForGenerated: XgoTrapForGenerated, + "__xgo_link_trap_var_for_generated": XgoTrapVarForGenerated, "__xgo_link_init_finished": "__xgo_init_finished", "__xgo_link_on_init_finished": "__xgo_on_init_finished", "__xgo_link_on_gonewproc": "__xgo_on_gonewproc", @@ -64,7 +69,7 @@ func isLinkValid(fnName string, targetName string, pkgPath string) bool { return pkgPath == "reflect" } - isLinkTrap := fnName == XgoLinkSetTrap + isLinkTrap := fnName == XgoLinkSetTrap || fnName == XgoLinkSetTrapVar if isLinkTrap { // the special trap return pkgPath == xgoRuntimeTrapPkg || strings.HasPrefix(pkgPath, xgoTestPkgPrefix) @@ -118,6 +123,10 @@ func replaceWithRuntimeCall(fn *ir.Func, name string) { getCallerPC := typecheck.LookupRuntime("getcallerpc") paramNames[1] = ir.NewCallExpr(fn.Pos(), ir.OCALL, getCallerPC, nil) } + if name == XgoTrapVarForGenerated { + // set pos to auto generated + fn.SetPos(base.AutogeneratedPos) + } resNames := getTypeNames(results) fnPos := fn.Pos() diff --git a/patch/syntax/func_stub_def.go b/patch/syntax/func_stub_def.go index fe3cc69e..2bb578d7 100644 --- a/patch/syntax/func_stub_def.go +++ b/patch/syntax/func_stub_def.go @@ -2,8 +2,10 @@ package syntax const expected__xgo_stub_def = `struct { PkgPath string + Kind int // 0 = func, 1 = var, 2=var_ptr 3 = const Fn interface{} - PC uintptr // filled later + Var interface{} // pointer to a variable if this is a declare variable + PC uintptr // filled later Interface bool Generic bool Closure bool // is the given function a closure diff --git a/patch/syntax/helper_code.go b/patch/syntax/helper_code.go index a87136f4..d0e39d8b 100644 --- a/patch/syntax/helper_code.go +++ b/patch/syntax/helper_code.go @@ -3,10 +3,14 @@ package syntax +const __xgo_local_pkg_name = "" // filled later + type __xgo_local_func_stub struct { PkgPath string + Kind int // 0 = func, 1 = var, 2=var_ptr 3 = const Fn interface{} - PC uintptr // filled later + Var interface{} // pointer to a variable if this is a declare variable + PC uintptr // filled later Interface bool Generic bool Closure bool // is the given function a closure @@ -46,6 +50,9 @@ func __xgo_link_trap_for_generated(pkgPath string, pc uintptr, identityName stri // linked by compiler return nil, false } +func __xgo_link_trap_var_for_generated(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) { + // linked by compiler +} func __xgo_link_generated_register_func(fn interface{}) { // linked later by compiler diff --git a/patch/syntax/helper_code_gen.go b/patch/syntax/helper_code_gen.go index 9c4ada78..04161a4e 100755 --- a/patch/syntax/helper_code_gen.go +++ b/patch/syntax/helper_code_gen.go @@ -4,8 +4,10 @@ package syntax const __xgo_stub_def = `struct { PkgPath string + Kind int // 0 = func, 1 = var, 2=var_ptr 3 = const Fn interface{} - PC uintptr // filled later + Var interface{} // pointer to a variable if this is a declare variable + PC uintptr // filled later Interface bool Generic bool Closure bool // is the given function a closure @@ -30,10 +32,14 @@ const __xgo_stub_def = `struct { const helperCodeGen = ` +const __xgo_local_pkg_name = "" // filled later + type __xgo_local_func_stub struct { PkgPath string + Kind int // 0 = func, 1 = var, 2=var_ptr 3 = const Fn interface{} - PC uintptr // filled later + Var interface{} // pointer to a variable if this is a declare variable + PC uintptr // filled later Interface bool Generic bool Closure bool // is the given function a closure @@ -73,6 +79,9 @@ func __xgo_link_trap_for_generated(pkgPath string, pc uintptr, identityName stri // linked by compiler return nil, false } +func __xgo_link_trap_var_for_generated(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) { + // linked by compiler +} func __xgo_link_generated_register_func(fn interface{}) { // linked later by compiler diff --git a/patch/syntax/rewrite.go b/patch/syntax/rewrite.go index e73653b6..e450b7b2 100644 --- a/patch/syntax/rewrite.go +++ b/patch/syntax/rewrite.go @@ -11,6 +11,7 @@ import ( const XgoLinkTrapForGenerated = "__xgo_link_trap_for_generated" +// for closure func fillFuncArgResNames(fileList []*syntax.File) { if base.Flag.Std { return @@ -35,6 +36,9 @@ func fillFuncArgResNames(fileList []*syntax.File) { func rewriteStdAndGenericFuncs(funcDecls []*DeclInfo, pkgPath string) { for _, fn := range funcDecls { + if !fn.Kind.IsFunc() { + continue + } if fn.Interface { continue } @@ -491,12 +495,26 @@ func newStringLit(s string) *syntax.BasicLit { Kind: syntax.StringLit, } } +func takeNameAddr(pos syntax.Pos, name string) *syntax.Operation { + return takeExprAddr(syntax.NewName(pos, name)) +} + +func takeExprAddr(expr syntax.Expr) *syntax.Operation { + return &syntax.Operation{ + Op: syntax.And, + X: expr, + } +} + func newIntLit(i int) *syntax.BasicLit { return &syntax.BasicLit{ Value: strconv.FormatInt(int64(i), 10), Kind: syntax.IntLit, } } +func newBool(pos syntax.Pos, b bool) *syntax.Name { + return syntax.NewName(pos, strconv.FormatBool(b)) +} // func newBoolLit(b bool) *syntax.BasicLit { // return &syntax.BasicLit{ diff --git a/patch/syntax/skip_go1.17_18_19.go b/patch/syntax/skip_go1.17_18_19.go deleted file mode 100644 index 3ebeea7d..00000000 --- a/patch/syntax/skip_go1.17_18_19.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build go1.17 && !go1.20 -// +build go1.17,!go1.20 - -package syntax - -import ( - xgo_ctxt "cmd/compile/internal/xgo_rewrite_internal/patch/ctxt" - "strings" -) - -func isSkippableSpecialPkg() bool { - curPkgPath := xgo_ctxt.GetPkgPath() - if strings.HasPrefix(curPkgPath, "golang.org/x/") { - return true - } - return false -} diff --git a/patch/syntax/skip_go1.20_and_above.go b/patch/syntax/skip_go1.20_and_above.go deleted file mode 100644 index 9381e856..00000000 --- a/patch/syntax/skip_go1.20_and_above.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build go1.20 -// +build go1.20 - -package syntax - -func isSkippableSpecialPkg() bool { - return false -} diff --git a/patch/syntax/syntax.go b/patch/syntax/syntax.go index d14a2923..35815fb0 100644 --- a/patch/syntax/syntax.go +++ b/patch/syntax/syntax.go @@ -13,6 +13,21 @@ import ( xgo_func_name "cmd/compile/internal/xgo_rewrite_internal/patch/func_name" ) +const XGO_TOOLCHAIN_VERSION = "XGO_TOOLCHAIN_VERSION" +const XGO_TOOLCHAIN_REVISION = "XGO_TOOLCHAIN_REVISION" +const XGO_TOOLCHAIN_VERSION_NUMBER = "XGO_TOOLCHAIN_VERSION_NUMBER" + +const XGO_VERSION = "XGO_VERSION" +const XGO_REVISION = "XGO_REVISION" +const XGO_NUMBER = "XGO_NUMBER" + +// this link function is considered safe as we do not allow user +// to define such one,there will no abuse +const XgoLinkGeneratedRegisterFunc = "__xgo_link_generated_register_func" +const XgoRegisterFuncs = "__xgo_register_funcs" +const XgoLocalFuncStub = "__xgo_local_func_stub" +const XgoLocalPkgName = "__xgo_local_pkg_name" + const sig_expected__xgo_register_func = `func(info interface{})` func init() { @@ -21,11 +36,11 @@ func init() { } } -func AfterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Reader)) { +func AfterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Reader) *syntax.File) { debugSyntax(fileList) patchVersions(fileList) fillFuncArgResNames(fileList) - afterFilesParsed(fileList, addFile) + registerFuncs(fileList, addFile) } func GetSyntaxDeclMapping() map[string]map[LineCol]*DeclInfo { @@ -69,6 +84,9 @@ func getSyntaxDeclMapping() map[string]map[LineCol]*DeclInfo { if syntaxDecl.Interface { continue } + if !syntaxDecl.Kind.IsFunc() { + continue + } file := syntaxDecl.File fileMapping := syntaxDeclMapping[file] if fileMapping == nil { @@ -111,17 +129,9 @@ func ClearSyntaxDeclMapping() { syntaxDeclMapping = nil } -const xgoRuntimePkgPrefix = "github.com/xhd2015/xgo/runtime/" - -// this link function is considered safe as we do not allow user -// to define such one,there will no abuse -const XgoLinkGeneratedRegisterFunc = "__xgo_link_generated_register_func" -const XgoRegisterFuncs = "__xgo_register_funcs" -const XgoLocalFuncStub = "__xgo_local_func_stub" - -func afterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Reader)) { +func registerFuncs(fileList []*syntax.File, addFile func(name string, r io.Reader) *syntax.File) { allFiles = fileList - if !shouldTrap() { + if xgo_ctxt.SkipPackageTrap() { return } var pkgName string @@ -141,7 +151,9 @@ func afterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Re // complexity, and runtime can be compiled or cached, we cannot locate // where its _pkg_.a is. - funcDelcs := getFuncDecls(fileList) + varTrap := allowVarTrap() + + funcDelcs := getFuncDecls(fileList, varTrap) for _, funcDecl := range funcDelcs { if funcDecl.RecvTypeName == "" && funcDecl.Name == XgoLinkGeneratedRegisterFunc { // ensure we are safe @@ -157,9 +169,23 @@ func afterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Re // std lib functions rewriteStdAndGenericFuncs(funcDelcs, pkgPath) + if varTrap { + trapVariables(fileList, funcDelcs) + // debug + // fmt.Fprintf(os.Stderr, "ast:") + // syntax.Fdump(os.Stderr, fileList[0]) + } + // always generate a helper to aid IR - addFile("__xgo_autogen_register_func_helper.go", strings.NewReader(generateRegHelperCode(pkgName))) + helperFile := addFile("__xgo_autogen_register_func_helper.go", strings.NewReader(generateRegHelperCode(pkgName))) + // change __xgo_local_pkg_name + for _, decl := range helperFile.DeclList { + if constDecl, ok := decl.(*syntax.ConstDecl); ok && constDecl.NameList[0].Value == XgoLocalPkgName { + constDecl.Values.(*syntax.BasicLit).Value = strconv.Quote(pkgPath) + break + } + } // split fileDecls to a list of batch // when statements gets large, it will // exceeds the compiler's threshold, causing @@ -193,14 +219,6 @@ func afterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Re } } -const XGO_TOOLCHAIN_VERSION = "XGO_TOOLCHAIN_VERSION" -const XGO_TOOLCHAIN_REVISION = "XGO_TOOLCHAIN_REVISION" -const XGO_TOOLCHAIN_VERSION_NUMBER = "XGO_TOOLCHAIN_VERSION_NUMBER" - -const XGO_VERSION = "XGO_VERSION" -const XGO_REVISION = "XGO_REVISION" -const XGO_NUMBER = "XGO_NUMBER" - func patchVersions(fileList []*syntax.File) { pkgPath := xgo_ctxt.GetPkgPath() if pkgPath != xgo_ctxt.XgoRuntimeCorePkg { @@ -249,26 +267,6 @@ func patchVersions(fileList []*syntax.File) { } } -func shouldTrap() bool { - if xgo_ctxt.SkipPackageTrap() { - return false - } - - pkgPath := xgo_ctxt.GetPkgPath() - if pkgPath == "" || strings.HasPrefix(pkgPath, "runtime/") || strings.HasPrefix(pkgPath, "internal/") || isSkippableSpecialPkg() { - // runtime/internal should not be rewritten - // internal/api has problem with the function register - return false - } - if strings.HasPrefix(pkgPath, xgoRuntimePkgPrefix) { - // skip all xgo runtime pkgs except test - if !strings.HasPrefix(pkgPath[len(xgoRuntimePkgPrefix):], "test/") { - return false - } - } - return true -} - func getFileIndexMapping(files []*syntax.File) map[*syntax.File]int { m := make(map[*syntax.File]int, len(files)) for i, file := range files { @@ -302,9 +300,31 @@ type FileDecl struct { File *syntax.File Funcs []*DeclInfo } +type DeclKind int + +const ( + Kind_Func DeclKind = 0 + Kind_Var DeclKind = 1 + Kind_VarPtr DeclKind = 2 + Kind_Const DeclKind = 3 + + // TODO + // Kind_Interface VarKind = 4 +) + +func (c DeclKind) IsFunc() bool { + return c == Kind_Func +} + +func (c DeclKind) IsVarOrConst() bool { + return c == Kind_Var || c == Kind_VarPtr || c == Kind_Const +} type DeclInfo struct { FuncDecl *syntax.FuncDecl + VarDecl *syntax.VarDecl + ConstDecl *syntax.ConstDecl + Kind DeclKind Name string RecvTypeName string RecvPtr bool @@ -334,6 +354,9 @@ func (c *DeclInfo) RefName() string { return "nil" } // if c.Generic, then the ref name is for generic + if !c.Kind.IsFunc() { + return c.Name + } return xgo_func_name.FormatFuncRefName(c.RecvTypeName, c.RecvPtr, c.Name) } @@ -348,6 +371,12 @@ func (c *DeclInfo) IdentityName() string { if c.Interface { return c.RecvTypeName } + if !c.Kind.IsFunc() { + if c.Kind == Kind_VarPtr { + return "*" + c.Name + } + return c.Name + } return xgo_func_name.FormatFuncRefName(c.RecvTypeName, c.RecvPtr, c.Name) } @@ -374,13 +403,13 @@ func fillName(field *syntax.Field, namePrefix string) { // collect funcs from files, register each of them by // calling to __xgo_reg_func with names and func pointer -func getFuncDecls(files []*syntax.File) []*DeclInfo { +func getFuncDecls(files []*syntax.File, varTrap bool) []*DeclInfo { // fileInfos := make([]*FileDecl, 0, len(files)) var declFuncs []*DeclInfo for i, f := range files { file := f.Pos().RelFilename() for _, decl := range f.DeclList { - fnDecls := extractFuncDecls(i, f, file, decl) + fnDecls := extractFuncDecls(i, f, file, decl, varTrap) declFuncs = append(declFuncs, fnDecls...) } } @@ -399,7 +428,7 @@ func filterFuncDecls(funcDecls []*DeclInfo, pkgPath string) []*DeclInfo { return filtered } -func extractFuncDecls(fileIndex int, f *syntax.File, file string, decl syntax.Decl) []*DeclInfo { +func extractFuncDecls(fileIndex int, f *syntax.File, file string, decl syntax.Decl, varTrap bool) []*DeclInfo { switch decl := decl.(type) { case *syntax.FuncDecl: info := getFuncDeclInfo(fileIndex, f, file, decl) @@ -407,6 +436,32 @@ func extractFuncDecls(fileIndex int, f *syntax.File, file string, decl syntax.De return nil } return []*DeclInfo{info} + case *syntax.VarDecl: + if !varTrap { + return nil + } + varDecls := collectVarDecls(Kind_Var, decl.NameList, decl.Type) + for _, varDecl := range varDecls { + varDecl.VarDecl = decl + + varDecl.FileSyntax = f + varDecl.FileIndex = fileIndex + varDecl.File = file + } + return varDecls + case *syntax.ConstDecl: + if !varTrap { + return nil + } + constDecls := collectVarDecls(Kind_Const, decl.NameList, decl.Type) + for _, constDecl := range constDecls { + constDecl.ConstDecl = decl + + constDecl.FileSyntax = f + constDecl.FileIndex = fileIndex + constDecl.File = file + } + return constDecls case *syntax.TypeDecl: if decl.Alias { return nil @@ -538,36 +593,52 @@ func generateFuncRegBody(funcDecls []*DeclInfo, xgoRegFunc string, xgoLocalFuncS // there are function with name "_" continue } - var refName string = "nil" - if !funcDecl.Generic { - refName = funcDecl.RefName() + var fnRefName string = "nil" + var varRefName string = "nil" + if funcDecl.Kind.IsFunc() { + if !funcDecl.Generic { + fnRefName = funcDecl.RefName() + } + } else if funcDecl.Kind == Kind_Var { + varRefName = "&" + funcDecl.RefName() + } else if funcDecl.Kind == Kind_Const { + varRefName = funcDecl.RefName() } fileIdx := funcDecl.FileIndex fileRef := getFileRef(fileIdx) // func(pkgPath string, fn interface{}, recvTypeName string, recvPtr bool, name string, identityName string, generic bool, recvName string, argNames []string, resNames []string, firstArgCtx bool, lastResErr bool, file string, line int) // check __xgo_local_func_stub for correctness - fieldList := []string{ - "__xgo_regPkgPath", // PkgPath - refName, // Fn - "0", // PC, filled later - strconv.FormatBool(funcDecl.Interface), // Interface - strconv.FormatBool(funcDecl.Generic), // Generic - strconv.FormatBool(funcDecl.Closure), // Closure - strconv.Quote(funcDecl.RecvTypeName), // RecvTypeName - strconv.FormatBool(funcDecl.RecvPtr), // RecvPtr - strconv.Quote(funcDecl.Name), // Name - strconv.Quote(funcDecl.IdentityName()), // IdentityName - strconv.Quote(funcDecl.RecvName), // RecvName - quoteNamesExpr(funcDecl.ArgNames), // ArgNames - quoteNamesExpr(funcDecl.ResNames), // ResNames - strconv.FormatBool(funcDecl.FirstArgCtx), // FirstArgCtx - strconv.FormatBool(funcDecl.LastResError), // LastResErr - fileRef, /* declFunc.FileRef */ // File - strconv.FormatInt(int64(funcDecl.Line), 10), // Line - } - fields := strings.Join(fieldList, ",") - stmts = append(stmts, fmt.Sprintf("%s(%s{%s})", xgoRegFunc, xgoLocalFuncStub, fields)) + regKind := func(kind DeclKind, identityName string) { + fieldList := []string{ + XgoLocalPkgName, // PkgPath + strconv.FormatInt(int64(kind), 10), // Kind + fnRefName, // Fn + varRefName, // Var + "0", // PC, filled later + strconv.FormatBool(funcDecl.Interface), // Interface + strconv.FormatBool(funcDecl.Generic), // Generic + strconv.FormatBool(funcDecl.Closure), // Closure + strconv.Quote(funcDecl.RecvTypeName), // RecvTypeName + strconv.FormatBool(funcDecl.RecvPtr), // RecvPtr + strconv.Quote(funcDecl.Name), // Name + strconv.Quote(identityName), // IdentityName + strconv.Quote(funcDecl.RecvName), // RecvName + quoteNamesExpr(funcDecl.ArgNames), // ArgNames + quoteNamesExpr(funcDecl.ResNames), // ResNames + strconv.FormatBool(funcDecl.FirstArgCtx), // FirstArgCtx + strconv.FormatBool(funcDecl.LastResError), // LastResErr + fileRef, /* declFunc.FileRef */ // File + strconv.FormatInt(int64(funcDecl.Line), 10), // Line + } + fields := strings.Join(fieldList, ",") + stmts = append(stmts, fmt.Sprintf("%s(%s{%s})", xgoRegFunc, xgoLocalFuncStub, fields)) + } + identityName := funcDecl.IdentityName() + regKind(funcDecl.Kind, identityName) + if funcDecl.Kind == Kind_Var { + regKind(Kind_VarPtr, "*"+identityName) + } // add files if !fileDeclaredMapping[fileIdx] { @@ -581,14 +652,13 @@ func generateFuncRegBody(funcDecls []*DeclInfo, xgoRegFunc string, xgoLocalFuncS return "" } allStmts := make([]string, 0, 2+len(fileDefs)+len(stmts)) - allStmts = append(allStmts, `__xgo_regPkgPath := `+strconv.Quote(xgo_ctxt.GetPkgPath())) if false { // debug allStmts = append(allStmts, `__xgo_reg_func_old:=__xgo_reg_func; __xgo_reg_func = func(info interface{}){ - fmt.Print("reg:"+__xgo_regPkgPath+"\n") + fmt.Print("reg:"+`+XgoLocalPkgName+`+"\n") v := reflect.ValueOf(info) if v.Kind() != reflect.Struct { - panic("non struct:"+__xgo_regPkgPath) + panic("non struct:"+`+XgoLocalPkgName+`) } __xgo_reg_func_old(info) }`) diff --git a/patch/syntax/vars.go b/patch/syntax/vars.go new file mode 100644 index 00000000..9d60ccbc --- /dev/null +++ b/patch/syntax/vars.go @@ -0,0 +1,532 @@ +package syntax + +import ( + "cmd/compile/internal/base" + "cmd/compile/internal/syntax" + xgo_ctxt "cmd/compile/internal/xgo_rewrite_internal/patch/ctxt" + "fmt" + "os" + "strconv" + "strings" +) + +func allowVarTrap() bool { + pkgPath := xgo_ctxt.GetPkgPath() + return allowPkgVarTrap(pkgPath) +} + +func allowPkgVarTrap(pkgPath string) bool { + // prevent all std variables + if base.Flag.Std { + return false + } + mainModule := xgo_ctxt.XgoMainModule + if mainModule == "" { + return false + } + + if strings.HasPrefix(pkgPath, mainModule) && (len(pkgPath) == len(mainModule) || pkgPath[len(mainModule)] == '/') { + return true + } + return false +} + +func collectVarDecls(declKind DeclKind, names []*syntax.Name, typ syntax.Expr) []*DeclInfo { + var decls []*DeclInfo + for _, name := range names { + line := name.Pos().Line() + decls = append(decls, &DeclInfo{ + Kind: declKind, + Name: name.Value, + + Line: int(line), + }) + } + return decls +} + +type vis struct { +} + +var _ syntax.Visitor = (*vis)(nil) + +// Visit implements syntax.Visitor. +func (c *vis) Visit(node syntax.Node) (w syntax.Visitor) { + return nil +} + +func trapVariables(fileList []*syntax.File, funcDelcs []*DeclInfo) { + names := make(map[string]*DeclInfo, len(funcDelcs)) + for _, funcDecl := range funcDelcs { + names[funcDecl.IdentityName()] = funcDecl + } + // iterate each file, find variable reference, + for _, file := range fileList { + imports := getImports(file) + for _, decl := range file.DeclList { + fnDecl, ok := decl.(*syntax.FuncDecl) + if !ok { + continue + } + if fnDecl.Body == nil { + continue + } + ctx := &BlockContext{} + ctx.traverseNode(fnDecl.Body, names, imports) + } + } +} +func getImports(file *syntax.File) map[string]string { + imports := make(map[string]string) + for _, decl := range file.DeclList { + impDecl, ok := decl.(*syntax.ImportDecl) + if !ok { + continue + } + pkgPath, err := strconv.Unquote(impDecl.Path.Value) + if err != nil { + continue + } + var localName string + if impDecl.LocalPkgName != nil { + localName = impDecl.LocalPkgName.Value + } else { + idx := strings.LastIndex(pkgPath, "/") + if idx < 0 { + localName = pkgPath + } else { + localName = pkgPath[idx+1:] + } + } + if localName == "" || localName == "." || localName == "_" { + continue + } + imports[localName] = pkgPath + } + return imports +} + +type BlockContext struct { + Parent *BlockContext + Block *syntax.BlockStmt + Index int + + Children []*BlockContext + + Names map[string]bool + + // to be inserted + InsertList []syntax.Stmt + + TrapNames []*NameAndDecl +} + +type NameAndDecl struct { + TakeAddr bool + Name *syntax.Name + Decl *DeclInfo +} + +func (c *BlockContext) Add(name string) { + if c.Names == nil { + c.Names = make(map[string]bool, 1) + } + c.Names[name] = true +} +func (c *BlockContext) Has(name string) bool { + if c == nil { + return false + } + _, ok := c.Names[name] + if ok { + return true + } + return c.Parent.Has(name) +} + +// imports: name -> pkgPath +func (ctx *BlockContext) traverseNode(node syntax.Node, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Node { + if node == nil { + return nil + } + switch node := node.(type) { + case syntax.Stmt: + return ctx.traverseStmt(node, globaleNames, imports) + case syntax.Expr: + return ctx.traverseExpr(node, globaleNames, imports) + case *syntax.CaseClause: + return ctx.traverseCaseClause(node, globaleNames, imports) + case *syntax.CommClause: + return ctx.traverseCommonClause(node, globaleNames, imports) + case *syntax.Field: + // ignore + default: + // unknown + if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { + panic(fmt.Errorf("unrecognized node: %T", node)) + } + } + return node +} + +func (ctx *BlockContext) traverseStmt(node syntax.Stmt, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Stmt { + if node == nil { + return nil + } + switch node := node.(type) { + case syntax.SimpleStmt: + return ctx.traverseSimpleStmt(node, globaleNames, imports) + case *syntax.BlockStmt: + return ctx.traverseBlockStmt(node, globaleNames, imports) + case *syntax.CallStmt: + // defer, go + node.Call = ctx.traverseExpr(node.Call, globaleNames, imports) + return node + case *syntax.IfStmt: + node.Init = ctx.traverseSimpleStmt(node.Init, globaleNames, imports) + node.Cond = ctx.traverseExpr(node.Cond, globaleNames, imports) + node.Then = ctx.traverseBlockStmt(node.Then, globaleNames, imports) + node.Else = ctx.traverseStmt(node.Else, globaleNames, imports) + return node + case *syntax.ForStmt: + node.Init = ctx.traverseSimpleStmt(node.Init, globaleNames, imports) + node.Cond = ctx.traverseExpr(node.Cond, globaleNames, imports) + node.Post = ctx.traverseSimpleStmt(node.Post, globaleNames, imports) + node.Body = ctx.traverseBlockStmt(node.Body, globaleNames, imports) + case *syntax.SwitchStmt: + node.Init = ctx.traverseSimpleStmt(node.Init, globaleNames, imports) + node.Tag = ctx.traverseExpr(node.Tag, globaleNames, imports) + for i, clause := range node.Body { + node.Body[i] = ctx.traverseCaseClause(clause, globaleNames, imports) + } + case *syntax.SelectStmt: + for i, clause := range node.Body { + node.Body[i] = ctx.traverseCommonClause(clause, globaleNames, imports) + } + case *syntax.DeclStmt: + for i, decl := range node.DeclList { + node.DeclList[i] = ctx.traverseDecl(decl, globaleNames, imports) + } + case *syntax.LabeledStmt: + node.Stmt = ctx.traverseStmt(node.Stmt, globaleNames, imports) + case *syntax.BranchStmt: + // ignore continue or continue label + case *syntax.ReturnStmt: + node.Results = ctx.traverseExpr(node.Results, globaleNames, imports) + default: + // unknown + if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { + panic(fmt.Errorf("unrecognized stmt: %T", node)) + } + } + return node +} + +func (ctx *BlockContext) traverseSimpleStmt(node syntax.SimpleStmt, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.SimpleStmt { + if node == nil { + return nil + } + switch node := node.(type) { + case *syntax.ExprStmt: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + return node + case *syntax.SendStmt: + node.Chan = ctx.traverseExpr(node.Chan, globaleNames, imports) + node.Value = ctx.traverseExpr(node.Value, globaleNames, imports) + case *syntax.AssignStmt: + if node.Op == syntax.Def { + // add name to current scope + if name, ok := node.Lhs.(*syntax.Name); ok { + ctx.Add(name.Value) + } else if names, ok := node.Lhs.(*syntax.ListExpr); ok { + for _, elem := range names.ElemList { + if name, ok := elem.(*syntax.Name); ok { + ctx.Add(name.Value) + } + } + } + } + node.Rhs = ctx.traverseExpr(node.Rhs, globaleNames, imports) + case *syntax.RangeClause: + if node.Lhs != nil && node.Def { + var fakeAssign syntax.Stmt = &syntax.AssignStmt{ + Op: syntax.Def, + Lhs: node.Lhs, + } + ctx.traverseStmt(fakeAssign, globaleNames, imports) + } + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + case *syntax.EmptyStmt: + // nothing + default: + // unknown + if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { + panic(fmt.Errorf("unrecognized simple stmt: %T", node)) + } + } + return node +} + +func (ctx *BlockContext) traverseBlockStmt(node *syntax.BlockStmt, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.BlockStmt { + if node == nil { + return nil + } + n := len(node.List) + for i := 0; i < n; i++ { + subCtx := &BlockContext{ + Parent: ctx, + Block: node, + Index: i, + } + ctx.Children = append(ctx.Children, subCtx) + node.List[i] = subCtx.traverseStmt(node.List[i], globaleNames, imports) + } + for i := n - 1; i >= 0; i-- { + node.List = insertBefore(node.List, i, ctx.Children[i].InsertList) + } + return node +} + +func (ctx *BlockContext) traverseCaseClause(node *syntax.CaseClause, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.CaseClause { + if node == nil { + return nil + } + node.Cases = ctx.traverseExpr(node.Cases, globaleNames, imports) + fakeBlock := &syntax.BlockStmt{ + List: node.Body, + } + fakeBlock = ctx.traverseBlockStmt(fakeBlock, globaleNames, imports) + node.Body = fakeBlock.List + return node +} + +func (ctx *BlockContext) traverseCommonClause(node *syntax.CommClause, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.CommClause { + if node == nil { + return nil + } + node.Comm = ctx.traverseSimpleStmt(node.Comm, globaleNames, imports) + fakeBlock := &syntax.BlockStmt{ + List: node.Body, + } + fakeBlock = ctx.traverseBlockStmt(fakeBlock, globaleNames, imports) + node.Body = fakeBlock.List + return node +} + +func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Expr { + if node == nil { + return nil + } + + switch node := node.(type) { + case *syntax.Name: + return ctx.trapValueNode(node, globaleNames) + case *syntax.CompositeLit: + for i, e := range node.ElemList { + node.ElemList[i] = ctx.traverseExpr(e, globaleNames, imports) + } + case *syntax.KeyValueExpr: + node.Value = ctx.traverseExpr(node.Value, globaleNames, imports) + case *syntax.FuncLit: + subCtx := &BlockContext{ + Parent: ctx, + } + // TODO: add names of types + ctx.Children = append(ctx.Children, subCtx) + node.Body = subCtx.traverseBlockStmt(node.Body, globaleNames, imports) + return node + case *syntax.ParenExpr: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + case *syntax.SelectorExpr: + return ctx.trapSelector(node, node, false, globaleNames, imports) + case *syntax.IndexExpr: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + node.Index = ctx.traverseExpr(node.Index, globaleNames, imports) + case *syntax.SliceExpr: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + for i := 0; i < len(node.Index); i++ { + node.Index[i] = ctx.traverseExpr(node.Index[i], globaleNames, imports) + } + case *syntax.AssertExpr: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + case *syntax.TypeSwitchGuard: + res := ctx.traverseExpr(node.X, globaleNames, imports) + if node.Lhs != nil { + ctx.Add(node.Lhs.Value) + } + return res + case *syntax.Operation: + // take addr? + if node.Op != syntax.And || node.Y != nil { + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + node.Y = ctx.traverseExpr(node.Y, globaleNames, imports) + return node + } + // &a, + switch x := node.X.(type) { + case *syntax.Name: + return ctx.trapAddrNode(node, x, globaleNames) + case *syntax.SelectorExpr: + return ctx.trapSelector(node, x, true, globaleNames, imports) + default: + node.X = ctx.traverseExpr(node.X, globaleNames, imports) + node.Y = ctx.traverseExpr(node.Y, globaleNames, imports) + } + case *syntax.CallExpr: + // NOTE: we skip capturing a name as a function + // node.Fun = ctx.traverseExpr(node.Fun, globaleNames, imports) + for i, arg := range node.ArgList { + node.ArgList[i] = ctx.traverseExpr(arg, globaleNames, imports) + } + case *syntax.ListExpr: + for i, elem := range node.ElemList { + node.ElemList[i] = ctx.traverseExpr(elem, globaleNames, imports) + } + // the following are ignored + case *syntax.ArrayType: + case *syntax.SliceType: + case *syntax.DotsType: + case *syntax.StructType: + case *syntax.InterfaceType: + case *syntax.FuncType: + case *syntax.ChanType: + case *syntax.MapType: + case *syntax.BasicLit: + case *syntax.BadExpr: + default: + // unknown + if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { + panic(fmt.Errorf("unrecognized expr: %T", node)) + } + } + return node +} + +func (ctx *BlockContext) traverseDecl(node syntax.Decl, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Decl { + if node == nil { + return nil + } + switch node := node.(type) { + case *syntax.ConstDecl: + case *syntax.TypeDecl: + case *syntax.VarDecl: + default: + // unknown + if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { + panic(fmt.Errorf("unrecognized stmt: %T", node)) + } + } + return node +} + +func (c *BlockContext) trapValueNode(node *syntax.Name, globaleNames map[string]*DeclInfo) syntax.Expr { + name := node.Value + if c.Has(name) { + return node + } + // TODO: what about dot import? + decl := globaleNames[name] + if decl == nil || !decl.Kind.IsVarOrConst() { + return node + } + preStmts, tmpVarName := trapVar(node, syntax.NewName(node.Pos(), XgoLocalPkgName), node.Value, false) + c.InsertList = append(c.InsertList, preStmts...) + return syntax.NewName(node.Pos(), tmpVarName) +} + +func (ctx *BlockContext) trapSelector(node syntax.Expr, sel *syntax.SelectorExpr, takeAddr bool, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Expr { + // form: pkg.var + nameNode, ok := sel.X.(*syntax.Name) + if !ok { + sel.X = ctx.traverseExpr(sel.X, globaleNames, imports) + return node + } + name := nameNode.Value + if ctx.Has(name) { + // local name + sel.X = ctx.traverseExpr(sel.X, globaleNames, imports) + return node + } + // import path + pkgPath := imports[name] + if pkgPath == "" { + sel.X = ctx.trapValueNode(nameNode, globaleNames) + return node + } + if !allowPkgVarTrap(pkgPath) { + return node + } + preStmts, tmpVarName := trapVar(node, newStringLit(pkgPath), sel.Sel.Value, takeAddr) + ctx.InsertList = append(ctx.InsertList, preStmts...) + return syntax.NewName(node.Pos(), tmpVarName) +} + +func (c *BlockContext) trapAddrNode(node *syntax.Operation, nameNode *syntax.Name, globaleNames map[string]*DeclInfo) syntax.Expr { + name := nameNode.Value + if c.Has(name) { + return node + } + // TODO: what about dot import? + decl := globaleNames[name] + if decl == nil || !decl.Kind.IsVarOrConst() { + return node + } + preStmts, tmpVarName := trapVar(node, syntax.NewName(nameNode.Pos(), XgoLocalPkgName), name, true) + c.InsertList = append(c.InsertList, preStmts...) + return syntax.NewName(node.Pos(), tmpVarName) +} + +func trapVar(expr syntax.Expr, pkgRef syntax.Expr, name string, takeAddr bool) (preStmts []syntax.Stmt, tmpVarName string) { + pos := expr.Pos() + line := pos.Line() + col := pos.Col() + + // a.b: + // ___m := a;__trap_var(&__m, &__a); + // __m.b + + // &a: + // __m:=&a; __trap_var(pkg,"a", &__m,takeAddr=true) + // &a -> __m + varName := fmt.Sprintf("__xgo_%s_%d_%d", name, line, col) + // a: + + preStmts = append(preStmts, &syntax.AssignStmt{ + Op: syntax.Def, + Lhs: syntax.NewName(pos, varName), + Rhs: expr, + }, + &syntax.ExprStmt{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(pos, "__xgo_link_trap_var_for_generated"), + ArgList: []syntax.Expr{ + pkgRef, + newStringLit(name), + &syntax.Operation{ + Op: syntax.And, + X: syntax.NewName(pos, varName), + }, + newBool(pos, takeAddr), + }, + }, + }, + // &syntax.ExprStmt{ + // X: &syntax.CallExpr{ + // Fun: syntax.NewName(pos, "panic"), + // ArgList: []syntax.Expr{ + // newStringLit(fmt.Sprintf("%s := %s; __trap_var(&%s, &%s)", varName, name.Value, varName, name.Value)), + // }, + // }, + // }, + ) + + for _, preStmt := range preStmts { + fillPos(pos, preStmt) + } + return preStmts, varName + +} + +func insertBefore(list []syntax.Stmt, i int, add []syntax.Stmt) []syntax.Stmt { + return append(append(list[:i:i], add...), list[i:]...) +} diff --git a/patch/trap.go b/patch/trap.go index 40f8674a..a6327cf3 100644 --- a/patch/trap.go +++ b/patch/trap.go @@ -192,6 +192,8 @@ func InsertTrapForFunc(fn *ir.Func, forGeneric bool) bool { if fn.OClosure == nil { return false } + + // register closure if isClosureWrapperForGeneric(fn) { // skip trap for generic closures, // but register for info access @@ -199,7 +201,7 @@ func InsertTrapForFunc(fn *ir.Func, forGeneric bool) bool { return false } isClosure = true - } else if decl.Interface || decl.Generic { + } else if decl.Interface || decl.Generic || !decl.Kind.IsFunc() { // interface just name return false } @@ -331,14 +333,6 @@ func CanInsertTrapOrLink(fn *ir.Func) (string, bool) { return "", false } - // skip all packages for xgo,except test - if strings.HasPrefix(pkgPath, xgoRuntimePkgPrefix) { - remain := pkgPath[len(xgoRuntimePkgPrefix):] - if !strings.HasPrefix(remain, "test/") && !strings.HasPrefix(remain, "runtime/test/") { - return "", false - } - } - // check if function body's first statement is a call to 'trap.Skip()' if isFirstStmtSkipTrap(fn.Body) { return "", false diff --git a/patch/trap_runtime/xgo_trap.go b/patch/trap_runtime/xgo_trap.go index 92a956d0..45c27a98 100644 --- a/patch/trap_runtime/xgo_trap.go +++ b/patch/trap_runtime/xgo_trap.go @@ -49,6 +49,15 @@ func __xgo_trap_for_generated(pkgPath string, pc uintptr, identityName string, g return __xgo_trap_impl(pkgPath, identityName, generic, fn.entry() /*>=go1.18*/, recv, args, results) } +var __xgo_trap_var_impl func(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) + +func __xgo_trap_var_for_generated(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) { + if __xgo_trap_var_impl == nil { + return + } + __xgo_trap_var_impl(pkgPath, name, tmpVarAddr, takeAddr) +} + func __xgo_set_trap(trap func(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool)) { if __xgo_trap_impl != nil { panic("trap already set by other packages") @@ -58,6 +67,15 @@ func __xgo_set_trap(trap func(pkgPath string, identityName string, generic bool, __xgo_trap_impl = trap } +func __xgo_set_trap_var(trap func(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool)) { + if __xgo_trap_var_impl != nil { + panic("trap var already set by other packages") + } + // ensure this init is called before main + // we do not care init here, we try our best + __xgo_trap_var_impl = trap +} + // NOTE: runtime has problem when using slice var __xgo_registered_func_infos []interface{} diff --git a/runtime/core/func.go b/runtime/core/func.go index 3608f032..b1b8bf2e 100644 --- a/runtime/core/func.go +++ b/runtime/core/func.go @@ -1,14 +1,40 @@ package core import ( + "fmt" "strings" ) const __XGO_SKIP_TRAP = true +type Kind int + +const ( + Kind_Func Kind = 0 + Kind_Var Kind = 1 + Kind_VarPtr Kind = 2 + Kind_Const Kind = 3 +) + +func (c Kind) String() string { + switch c { + case Kind_Func: + return "func" + case Kind_Var: + return "var" + case Kind_VarPtr: + return "var_ptr" + case Kind_Const: + return "const" + default: + return fmt.Sprintf("kind_%d", int(c)) + } +} + type FuncInfo struct { // full name, format: {pkgPath}.{receiver}.{funcName} // example: github.com/xhd2015/xgo/runtime/core.(*FuncInfo).IsFunc + Kind Kind FullName string Pkg string IdentityName string @@ -29,8 +55,10 @@ type FuncInfo struct { File string Line int - PC uintptr `json:"-"` - Func interface{} `json:"-"` + PC uintptr `json:"-"` + Func interface{} `json:"-"` + Var interface{} `json:"-"` // var address + RecvName string ArgNames []string ResNames []string diff --git a/runtime/core/version.go b/runtime/core/version.go index cf856700..7fe8a935 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.18" -const REVISION = "03d82b3e31832e5947c5d3a7ef8752f4f39db28c+1" -const NUMBER = 162 +const REVISION = "1211c519c8005ddbd66189cf64e958aa69e5789f+1" +const NUMBER = 163 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/functab/functab.go b/runtime/functab/functab.go index 2bc1bde5..70cec131 100644 --- a/runtime/functab/functab.go +++ b/runtime/functab/functab.go @@ -37,6 +37,7 @@ func __xgo_link_get_pc_name(pc uintptr) string { var funcInfos []*core.FuncInfo var funcInfoMapping map[string]map[string]*core.FuncInfo // pkg -> identifyName -> FuncInfo var funcPCMapping map[uintptr]*core.FuncInfo // pc->FuncInfo +var varAddrMapping map[uintptr]*core.FuncInfo // addr->FuncInfo var funcFullNameMapping map[string]*core.FuncInfo // fullName -> FuncInfo var interfaceMapping map[string]map[string]*core.FuncInfo // pkg -> interfaceName -> FuncInfo var typeMethodMapping map[reflect.Type]map[string]*core.FuncInfo // reflect.Type -> interfaceName -> FuncInfo @@ -56,6 +57,15 @@ func InfoFunc(fn interface{}) *core.FuncInfo { pc := v.Pointer() return funcPCMapping[pc] } +func InfoVar(addr interface{}) *core.FuncInfo { + ensureMapping() + v := reflect.ValueOf(addr) + if v.Kind() != reflect.Pointer { + panic(fmt.Errorf("given type is not a pointer: %T", addr)) + } + ptr := v.Pointer() + return varAddrMapping[ptr] +} // maybe rename to FuncForPC func InfoPC(pc uintptr) *core.FuncInfo { @@ -144,6 +154,7 @@ func ensureMapping() { funcInfoMapping = make(map[string]map[string]*core.FuncInfo) funcFullNameMapping = make(map[string]*core.FuncInfo) interfaceMapping = make(map[string]map[string]*core.FuncInfo) + varAddrMapping = make(map[uintptr]*core.FuncInfo) __xgo_link_retrieve_all_funcs_and_clear(func(fnInfo interface{}) { rv := reflect.ValueOf(fnInfo) if rv.Kind() != reflect.Struct { @@ -157,7 +168,12 @@ func ensureMapping() { } // fmt.Fprintf(os.Stderr, "empty name\n",pkgPath) } - + var fnKind core.Kind + fnKindV := rv.FieldByName("Kind") + if fnKindV.IsValid() { + fnKind = core.Kind(fnKindV.Int()) + } + varField := rv.FieldByName("Var") pkgPath := rv.FieldByName("PkgPath").String() recvTypeName := rv.FieldByName("RecvTypeName").String() recvPtr := rv.FieldByName("RecvPtr").Bool() @@ -188,7 +204,7 @@ func ensureMapping() { pc = getFuncPC(f) fullName = __xgo_link_get_pc_name(pc) } else { - if closure && identityName != "" { + if (closure || fnKind == core.Kind_Var || fnKind == core.Kind_VarPtr || fnKind == core.Kind_Const) && identityName != "" { fullName = pkgPath + "." + identityName } } @@ -206,6 +222,7 @@ func ensureMapping() { // } // _, recvTypeName, recvPtr, name := core.ParseFuncName(identityName, false) info := &core.FuncInfo{ + Kind: fnKind, FullName: fullName, Pkg: pkgPath, IdentityName: identityName, @@ -232,8 +249,11 @@ func ensureMapping() { FirstArgCtx: firstArgCtx, LastResultErr: lastResErr, } + if varField.IsValid() { + info.Var = varField.Interface() + } funcInfos = append(funcInfos, info) - if !generic { + if !generic && info.PC != 0 { funcPCMapping[info.PC] = info } if identityName != "" { @@ -252,6 +272,12 @@ func ensureMapping() { } pkgMapping[recvTypeName] = info } + if fnKind == core.Kind_Var { + if varField.IsValid() { + varAddr := varField.Elem().Pointer() + varAddrMapping[varAddr] = info + } + } if fullName != "" { funcFullNameMapping[fullName] = info } diff --git a/runtime/mock/mock.go b/runtime/mock/mock.go index 223a3b80..296e79af 100644 --- a/runtime/mock/mock.go +++ b/runtime/mock/mock.go @@ -47,7 +47,11 @@ func getFunc(fn interface{}) (recvPtr interface{}, fnInfo *core.FuncInfo, funcPC recvPtr, fnInfo, funcPC, trappingPC = trap.InspectPC(fn) if fnInfo == nil { pc := reflect.ValueOf(fn).Pointer() - panic(fmt.Errorf("failed to setup mock for: %v", runtime.FuncForPC(pc).Name())) + fn := runtime.FuncForPC(pc) + if fn == nil { + panic(fmt.Errorf("failed to setup mock for variable: 0x%x", pc)) + } + panic(fmt.Errorf("failed to setup mock for: %v", fn.Name())) } return recvPtr, fnInfo, funcPC, trappingPC } @@ -96,7 +100,7 @@ func AddFuncInterceptor(fn interface{}, interceptor Interceptor) func() { func mock(mockRecvPtr interface{}, mockFnInfo *core.FuncInfo, funcPC uintptr, trappingPC uintptr, interceptor Interceptor) func() { return trap.AddInterceptor(&trap.Interceptor{ Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { - if f.PC == 0 { + if f.Kind == core.Kind_Func && f.PC == 0 { if !f.Generic { if !f.Closure || trap.ClosureHasFunc { return nil, nil diff --git a/runtime/mock/patch.go b/runtime/mock/patch.go index 6b866ec8..8de02522 100644 --- a/runtime/mock/patch.go +++ b/runtime/mock/patch.go @@ -22,11 +22,19 @@ func Patch(fn interface{}, replacer interface{}) func() { panic("replacer cannot be nil") } fnType := reflect.TypeOf(fn) - if fnType.Kind() != reflect.Func { - panic(fmt.Errorf("fn should be func, actual: %T", fn)) - } - if fnType != reflect.TypeOf(replacer) { - panic(fmt.Errorf("replacer should have type: %T, actual: %T", fn, replacer)) + fnKind := fnType.Kind() + if fnKind == reflect.Func { + if fnType != reflect.TypeOf(replacer) { + panic(fmt.Errorf("replacer should have type: %T, actual: %T", fn, replacer)) + } + } else if fnKind == reflect.Pointer { + if reflect.TypeOf(replacer).Kind() != reflect.Func { + // TODO: validate return value + t := fnType.Elem().String() + panic(fmt.Errorf("replacer should be a func()%s or func()*%s, actual: %T", t, t, replacer)) + } + } else { + panic(fmt.Errorf("fn should be func or pointer to variable, actual: %T", fn)) } recvPtr, fnInfo, funcPC, trappingPC := getFunc(fn) @@ -178,6 +186,10 @@ func checkFuncTypeMatch(a reflect.Type, b reflect.Type, skipAFirst bool) (atype return "", "", true } +// func func(fn reflect.Type, in []reflect.Type, out []reflect.Type, vardaric bool, skipAFirst bool) (atype string, btype string, match bool) { + +// } + func formatFuncType(f reflect.Type, skipFirst bool) string { n := f.NumIn() i := 0 diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go index 2bfc2c3d..d18d33d6 100644 --- a/runtime/test/debug/debug_test.go +++ b/runtime/test/debug/debug_test.go @@ -6,18 +6,21 @@ package debug import ( + "context" "testing" + "github.com/xhd2015/xgo/runtime/core" "github.com/xhd2015/xgo/runtime/mock" + "github.com/xhd2015/xgo/runtime/test/mock_var/sub" ) -func TestDebug(t *testing.T) { - mock.Patch(greet, func(s string) string { - return "mock " + s +func TestMockVarInOtherPkg(t *testing.T) { + mock.Mock(&sub.A, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + results.GetFieldIndex(0).Set("mockA") + return nil }) - greet("world") -} - -func greet(s string) string { - return "hello " + s + b := sub.A + if b != "mockA" { + t.Fatalf("expect sub.A to be %s, actual: %s", "mockA", b) + } } diff --git a/runtime/test/func_list/func_list_test.go b/runtime/test/func_list/func_list_test.go index a1d53b90..5b9cc87d 100644 --- a/runtime/test/func_list/func_list_test.go +++ b/runtime/test/func_list/func_list_test.go @@ -11,9 +11,9 @@ const testPkgPath = "github.com/xhd2015/xgo/runtime/test/func_list" var addExtraPkgsAssert func(m map[string]bool) -// go run ./script/run-test/ --include go1.17.13 --xgo-runtime-test-only -run TestFuncList -v ./test/func_list -// go run ./cmd/xgo test --project-dir runtime -run TestFuncList -v ./test/func_list -func TestFuncList(t *testing.T) { +// go run ./script/run-test/ --include go1.17.13 --xgo-runtime-test-only -run TestFuncListFn -v ./test/func_list +// go run ./cmd/xgo test --project-dir runtime -run TestFuncListFn -v ./test/func_list +func TestFuncListFn(t *testing.T) { funcs := functab.GetFuncs() missingPkgs := map[string]bool{ diff --git a/runtime/test/func_list/func_list_var_test.go b/runtime/test/func_list/func_list_var_test.go new file mode 100644 index 00000000..7b59efba --- /dev/null +++ b/runtime/test/func_list/func_list_var_test.go @@ -0,0 +1,46 @@ +package func_list + +import ( + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/functab" + "github.com/xhd2015/xgo/runtime/test/func_list/sub" +) + +const a int = 10 + +var b struct{} + +const subPkgPath = testPkgPath + "/sub" + +func TestFuncListVar(t *testing.T) { + fna := functab.Info(testPkgPath, "a") + if fna.Kind != core.Kind_Const { + t.Fatalf("expect a.Kind to be %v, actual: %v", core.Kind_Const, fna.Kind) + } + + fnaPtr := functab.Info(testPkgPath, "*a") + if fnaPtr != nil { + t.Fatalf("expect aptr to be nil, actual: %v", fnaPtr) + } + + fnb := functab.Info(testPkgPath, "b") + if fnb.Kind != core.Kind_Var { + t.Fatalf("expect b.Kind to be %v, actual: %v", core.Kind_Var, fnb.Kind) + } + + fnbPtr := functab.Info(testPkgPath, "*b") + if fnbPtr.Kind != core.Kind_VarPtr { + t.Fatalf("expect bptr.Kind to be %v, actual: %v", core.Kind_VarPtr, fnbPtr.Kind) + } +} + +var _ = sub.A + +func TestFuncListSubPkgVar(t *testing.T) { + fnA := functab.Info(subPkgPath, "A") + if fnA.Kind != core.Kind_Var { + t.Fatalf("expect fnA.Kind to be %v, actual: %v", core.Kind_Var, fnA.Kind) + } +} diff --git a/runtime/test/func_list/sub/sub.go b/runtime/test/func_list/sub/sub.go new file mode 100644 index 00000000..4e5c12f9 --- /dev/null +++ b/runtime/test/func_list/sub/sub.go @@ -0,0 +1,5 @@ +package sub + +var A int = 10 + +var b int = 20 diff --git a/runtime/test/mock_var/mock_var_test.go b/runtime/test/mock_var/mock_var_test.go new file mode 100644 index 00000000..b5dd2d83 --- /dev/null +++ b/runtime/test/mock_var/mock_var_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/mock" + "github.com/xhd2015/xgo/runtime/test/mock_var/sub" +) + +var a int = 123 + +// TODO: support xgo:notrap +// xgo:notrap +var b int + +func TestMockVarTest(t *testing.T) { + mock.Mock(&a, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + results.GetFieldIndex(0).Set(456) + return nil + }) + b := a + if b != 456 { + t.Fatalf("expect b to be %d, actual: %d", 456, b) + } +} + +func TestMockVarInOtherPkg(t *testing.T) { + mock.Mock(&sub.A, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + results.GetFieldIndex(0).Set("mockA") + return nil + }) + b := sub.A + if b != "mockA" { + t.Fatalf("expect sub.A to be %s, actual: %s", "mockA", b) + } +} diff --git a/runtime/test/mock_var/sub/sub.go b/runtime/test/mock_var/sub/sub.go new file mode 100644 index 00000000..3aaa8733 --- /dev/null +++ b/runtime/test/mock_var/sub/sub.go @@ -0,0 +1,3 @@ +package sub + +var A string = "subA" diff --git a/runtime/test/patch/patch_var_test.go b/runtime/test/patch/patch_var_test.go new file mode 100644 index 00000000..f6df416e --- /dev/null +++ b/runtime/test/patch/patch_var_test.go @@ -0,0 +1,19 @@ +package patch + +import ( + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +var a int = 123 + +func TestPatchVarTest(t *testing.T) { + mock.Patch(&a, func() int { + return 456 + }) + b := a + if b != 456 { + t.Fatalf("expect patched varaibel a to be %d, actual: %d", 456, b) + } +} diff --git a/runtime/trap/inspect.go b/runtime/trap/inspect.go index d6d3b815..3766a59f 100644 --- a/runtime/trap/inspect.go +++ b/runtime/trap/inspect.go @@ -24,6 +24,12 @@ func Inspect(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo) { func InspectPC(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo, funcPC uintptr, trappingPC uintptr) { fn := reflect.ValueOf(f) + // try as a variable + if fn.Kind() == reflect.Ptr { + // a variable + funcInfo = functab.InfoVar(f) + return nil, funcInfo, 0, 0 + } if fn.Kind() != reflect.Func { panic(fmt.Errorf("Inspect requires func, given: %s", fn.Kind().String())) } diff --git a/runtime/trap/trap.go b/runtime/trap/trap.go index 8d8d0a95..81a5955f 100644 --- a/runtime/trap/trap.go +++ b/runtime/trap/trap.go @@ -15,7 +15,8 @@ var setupOnce sync.Once func ensureTrapInstall() { setupOnce.Do(func() { - __xgo_link_set_trap(trapImpl) + __xgo_link_set_trap(trapFunc) + __xgo_link_set_trap_var(trapVar) }) } func init() { @@ -40,6 +41,10 @@ func init() { func __xgo_link_set_trap(trapImpl func(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool)) { fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_set_trap(requires xgo).") } + +func __xgo_link_set_trap_var(trap func(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool)) { + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_set_trap_var(requires xgo).") +} func __xgo_link_on_gonewproc(f func(g uintptr)) { fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_gonewproc(requires xgo).") } @@ -57,29 +62,16 @@ var trappingPC sync.Map // -> PC // link to runtime // xgo:notrap -func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { - if isByPassing() { - return nil, false - } - dispose := setTrappingMark() - if dispose == nil { +func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { + interceptors := GetAllInterceptors() + n := len(interceptors) + if n == 0 { return nil, false } - defer dispose() - // setup context setTrappingPC(pc) defer clearTrappingPC() - type intf struct { - _ uintptr - pc *uintptr - } - interceptors := GetAllInterceptors() - n := len(interceptors) - if n == 0 { - return nil, false - } // NOTE: this may return nil for generic template var f *core.FuncInfo if !generic { @@ -90,13 +82,51 @@ func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, rec f = functab.Info(pkgPath, identityName) } if f == nil { - // + // no func found return nil, false } if f.RecvType != "" && methodHasBeenTrapped && recv == nil { + // method // let go to the next interceptor return nil, false } + return trap(f, interceptors, recv, args, results) +} + +func trapVar(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) { + interceptors := GetAllInterceptors() + n := len(interceptors) + if n == 0 { + return + } + identityName := name + if takeAddr { + identityName = "*" + name + } + fnInfo := functab.Info(pkgPath, identityName) + if fnInfo == nil { + return + } + if fnInfo.Kind != core.Kind_Var && fnInfo.Kind != core.Kind_VarPtr && fnInfo.Kind != core.Kind_Const { + return + } + // NOTE: stop always ignored because this is a simple get + post, _ := trap(fnInfo, interceptors, nil, nil, []interface{}{tmpVarAddr}) + if post != nil { + // NOTE: must in defer, because in post we + // may capture panic + defer post() + } +} +func trap(f *core.FuncInfo, interceptors []*Interceptor, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { + if isByPassing() { + return nil, false + } + dispose := setTrappingMark() + if dispose == nil { + return nil, false + } + defer dispose() // retrieve context var ctx context.Context @@ -186,6 +216,7 @@ func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, rec } abortIdx := -1 + n := len(interceptors) dataList := make([]interface{}, n) for i := n - 1; i >= 0; i-- { interceptor := interceptors[i] diff --git a/script/build-compiler/main.go b/script/build-compiler/main.go index 621d0158..8c50c975 100644 --- a/script/build-compiler/main.go +++ b/script/build-compiler/main.go @@ -10,6 +10,7 @@ func main() { args := os.Args[1:] execArgs := []string{ "run", + "-tags", "dev", "./cmd/xgo", "build", "--xgo-src", diff --git a/script/run-test/main.go b/script/run-test/main.go index 58c6135c..7243c2a9 100644 --- a/script/run-test/main.go +++ b/script/run-test/main.go @@ -44,6 +44,7 @@ var runtimeSubTests = []string{ "mock_closure", "mock_stdlib", "mock_generic", + "mock_var", "trap_args", "patch", } @@ -373,7 +374,7 @@ func doRunTest(goroot string, kind testKind, args []string, tests []string) erro testArgs = append(testArgs, "./cmd/...") } case testKind_xgoTest: - testArgs = []string{"run", "./cmd/xgo", "test", "-tags", "dev"} + testArgs = []string{"run", "-tags", "dev", "./cmd/xgo", "test"} testArgs = append(testArgs, args...) if len(tests) > 0 { testArgs = append(testArgs, tests...) @@ -381,7 +382,7 @@ func doRunTest(goroot string, kind testKind, args []string, tests []string) erro testArgs = append(testArgs, "./test/xgo_test/...") } case testKind_runtimeTest: - testArgs = []string{"run", "./cmd/xgo", "test", "--project-dir", "runtime", "-tags", "dev"} + testArgs = []string{"run", "-tags", "dev", "./cmd/xgo", "test", "--project-dir", "runtime"} testArgs = append(testArgs, args...) if len(tests) > 0 { testArgs = append(testArgs, tests...) diff --git a/script/setup-dev/main.go b/script/setup-dev/main.go index 4bfc66a7..f0bbdd70 100644 --- a/script/setup-dev/main.go +++ b/script/setup-dev/main.go @@ -14,6 +14,7 @@ func main() { args := os.Args[1:] execArgs := []string{ "run", + "-tags", "dev", "./cmd/xgo", "build", "--xgo-src", diff --git a/support/goinfo/mod.go b/support/goinfo/mod.go new file mode 100644 index 00000000..4af1cd7a --- /dev/null +++ b/support/goinfo/mod.go @@ -0,0 +1,120 @@ +package goinfo + +import ( + "errors" + "os" + "path/filepath" + "strings" + + "github.com/xhd2015/xgo/support/osinfo" +) + +var ErrGoModNotFound = errors.New("go.mod not found") +var ErrGoModDoesNotHaveModule = errors.New("go.mod does not have module") + +func ResolveMainModule(dir string, args []string) (string, error) { + goMod, _, err := findGoMod(dir) + if err != nil { + return "", err + } + + goModContent, err := os.ReadFile(goMod) + if err != nil { + return "", err + } + modPath := parseModPath(string(goModContent)) + if modPath == "" { + return "", ErrGoModDoesNotHaveModule + } + + return modPath, nil + + // // has quailified name: not starting with ./ or ../ + // var qualifieldNames []string + // for _, arg := range args { + // if !isRelative(arg) { + // qualifieldNames = append(qualifieldNames, arg) + // } else { + + // } + // } + + // return "", nil +} + +func isRelative(arg string) bool { + if arg == "" { + // pwd + return true + } + n := len(arg) + if arg[0] != '.' { + return false + } + if n == 1 || arg[1] == '/' || (osinfo.IS_WINDOWS && arg[1] == '\\') { + // . ./ .\ + return true + } + if arg[1] != '.' { + return false + } + return n == 2 || arg[2] == '/' || (osinfo.IS_WINDOWS && arg[2] == '\\') +} + +func findGoMod(dir string) (file string, subPaths []string, err error) { + var absDir string + if dir == "" { + absDir, err = os.Getwd() + } else { + absDir, err = filepath.Abs(dir) + } + if err != nil { + return "", nil, err + } + iterDir := absDir + init := true + for { + if init { + init = false + } else { + subPaths = append(subPaths, filepath.Base(iterDir)) + nextIterDir := filepath.Dir(iterDir) + if iterDir == string(filepath.Separator) || nextIterDir == iterDir { + // until root + // TODO: what about windows? + return "", nil, ErrGoModNotFound + } + iterDir = nextIterDir + } + file := filepath.Join(iterDir, "go.mod") + stat, err := os.Stat(file) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return "", nil, err + } + continue + } + if stat.IsDir() { + continue + } + // a valid go.mod found + return file, subPaths, nil + } +} + +func parseModPath(goModContent string) string { + lines := strings.Split(string(goModContent), "\n") + n := len(lines) + for i := 0; i < n; i++ { + line := strings.TrimSpace(lines[i]) + if strings.HasPrefix(line, "module ") { + module := strings.TrimSpace(line[len("module "):]) + commentIdx := strings.Index(module, "//") + if commentIdx >= 0 { + module = strings.TrimSpace(module[:commentIdx]) + } + return module + } + } + return "" +} diff --git a/support/goinfo/mod_test.go b/support/goinfo/mod_test.go new file mode 100644 index 00000000..19aba124 --- /dev/null +++ b/support/goinfo/mod_test.go @@ -0,0 +1,29 @@ +package goinfo + +import "testing" + +func TestParseMode(t *testing.T) { + testCases := []struct { + Content string + Module string + }{ + { + ` module a`, + `a`, + }, + { + `module a/bc//yes it me`, + `a/bc`, + }, + { + "go 1.18\r\nmodule a/bc//windows it me\r\n", + `a/bc`, + }, + } + for _, tc := range testCases { + m := parseModPath(tc.Content) + if m != tc.Module { + t.Fatalf("expect parseModPath(%q) to be %q, actual: %q", tc.Content, tc.Module, m) + } + } +} diff --git a/support/osinfo/osinfo_nonwin.go b/support/osinfo/osinfo_nonwin.go index 074cf873..d672bd2a 100644 --- a/support/osinfo/osinfo_nonwin.go +++ b/support/osinfo/osinfo_nonwin.go @@ -8,3 +8,5 @@ const EXE_SUFFIX = "" // when copy files, should use // symbolic as long as possible const FORCE_COPY_UNSYM = false + +const IS_WINDOWS = true diff --git a/support/osinfo/osinfo_win.go b/support/osinfo/osinfo_win.go index dd4f07ff..ccf357db 100644 --- a/support/osinfo/osinfo_win.go +++ b/support/osinfo/osinfo_win.go @@ -8,3 +8,5 @@ const EXE_SUFFIX = ".exe" // when copy files, don't use // symbolic as it may cause failure const FORCE_COPY_UNSYM = true + +const IS_WINDOWS = false