diff --git a/internal/injector/aspect/context/context.go b/internal/injector/aspect/context/context.go index 5eed7744..7ffa0288 100644 --- a/internal/injector/aspect/context/context.go +++ b/internal/injector/aspect/context/context.go @@ -224,8 +224,13 @@ func (c *context) ParseSource(bytes []byte) (*dst.File, error) { return c.sourceParser.Parse(bytes) } -func (c *context) AddImport(path string, alias string) bool { - return c.refMap.AddImport(c.file, path, alias) +func (c *context) AddImport(path string, name string) bool { + nodeChain := []dst.Node{c.node} + for p := c.NodeChain.parent; p != nil; p = p.parent { + nodeChain = append(nodeChain, p.node) + } + + return c.refMap.AddImport(c.file, nodeChain, path, name) } func (c *context) AddLink(path string) bool { diff --git a/internal/injector/check.go b/internal/injector/check.go index 759e4647..61068806 100644 --- a/internal/injector/check.go +++ b/internal/injector/check.go @@ -16,9 +16,12 @@ import ( // typeCheck runs the Go type checker on the provided files, and returns the // Uses type information map that is built in the process. -func (i *Injector) typeCheck(fset *token.FileSet, files []*ast.File) (map[*ast.Ident]types.Object, error) { +func (i *Injector) typeCheck(fset *token.FileSet, files []*ast.File) (types.Info, error) { pkg := types.NewPackage(i.ImportPath, i.Name) - typeInfo := types.Info{Uses: make(map[*ast.Ident]types.Object)} + typeInfo := types.Info{ + Uses: make(map[*ast.Ident]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + } checkerCfg := types.Config{ GoVersion: i.GoVersion, @@ -27,8 +30,8 @@ func (i *Injector) typeCheck(fset *token.FileSet, files []*ast.File) (map[*ast.I checker := types.NewChecker(&checkerCfg, fset, pkg, &typeInfo) if err := checker.Files(files); err != nil { - return nil, fmt.Errorf("type-checking files: %w", err) + return types.Info{}, fmt.Errorf("type-checking files: %w", err) } - return typeInfo.Uses, nil + return typeInfo, nil } diff --git a/internal/injector/injector.go b/internal/injector/injector.go index 8d324ae9..9b77a2f2 100644 --- a/internal/injector/injector.go +++ b/internal/injector/injector.go @@ -14,6 +14,7 @@ import ( "go/ast" "go/importer" "go/token" + "go/types" "sync" "github.com/DataDog/orchestrion/internal/injector/aspect" @@ -84,7 +85,7 @@ func (i *Injector) InjectFiles(files []string) (map[string]InjectedFile, context return nil, context.GoLangVersion{}, err } - uses, err := i.typeCheck(fset, astFiles) + typeInfo, err := i.typeCheck(fset, astFiles) if err != nil { return nil, context.GoLangVersion{}, err } @@ -103,7 +104,7 @@ func (i *Injector) InjectFiles(files []string) (map[string]InjectedFile, context go func(idx int, astFile *ast.File) { defer wg.Done() - decorator := decorator.NewDecoratorWithImports(fset, i.ImportPath, gotypes.New(uses)) + decorator := decorator.NewDecoratorWithImports(fset, i.ImportPath, gotypes.New(typeInfo.Uses)) dstFile, err := decorator.DecorateFile(astFile) if err != nil { errsMu.Lock() @@ -112,7 +113,7 @@ func (i *Injector) InjectFiles(files []string) (map[string]InjectedFile, context return } - res, err := i.injectFile(decorator, dstFile) + res, err := i.injectFile(decorator, dstFile, typeInfo) if err != nil { errsMu.Lock() defer errsMu.Unlock() @@ -152,11 +153,11 @@ func (i *Injector) validate() error { // injectFile injects code in the specified file. This method can be called concurrently by multiple goroutines, // as is guarded by a sync.Mutex. -func (i *Injector) injectFile(decorator *decorator.Decorator, file *dst.File) (result, error) { +func (i *Injector) injectFile(decorator *decorator.Decorator, file *dst.File, typeInfo types.Info) (result, error) { result := result{InjectedFile: InjectedFile{Filename: decorator.Filenames[file]}} var err error - result.Modified, result.References, result.GoLang, err = i.applyAspects(decorator, file, i.RootConfig) + result.Modified, result.References, result.GoLang, err = i.applyAspects(decorator, file, i.RootConfig, typeInfo) if err != nil { return result, err } @@ -171,11 +172,11 @@ func (i *Injector) injectFile(decorator *decorator.Decorator, file *dst.File) (r return result, nil } -func (i *Injector) applyAspects(decorator *decorator.Decorator, file *dst.File, rootConfig map[string]string) (bool, typed.ReferenceMap, context.GoLangVersion, error) { +func (i *Injector) applyAspects(decorator *decorator.Decorator, file *dst.File, rootConfig map[string]string, typeInfo types.Info) (bool, typed.ReferenceMap, context.GoLangVersion, error) { var ( chain *context.NodeChain modified bool - references typed.ReferenceMap + references = typed.NewReferenceMap(decorator.Ast.Nodes, typeInfo.Scopes) err error ) @@ -183,6 +184,7 @@ func (i *Injector) applyAspects(decorator *decorator.Decorator, file *dst.File, if err != nil || csor.Node() == nil || isIgnored(csor.Node()) { return false } + root := chain == nil chain = chain.Child(csor) if root { diff --git a/internal/injector/testdata/injector/chi5-newroute-dotimport/config.yml b/internal/injector/testdata/injector/chi5-newroute-dotimport/config.yml index aac3a0c1..ec2805d0 100644 --- a/internal/injector/testdata/injector/chi5-newroute-dotimport/config.yml +++ b/internal/injector/testdata/injector/chi5-newroute-dotimport/config.yml @@ -16,6 +16,7 @@ aspects: }() syntheticReferences: + github.com/go-chi/chi/v5: true gopkg.in/DataDog/dd-trace-go.v1/contrib/go-chi/chi.v5: true code: |- diff --git a/internal/injector/testdata/injector/import-shadowing/config.yml b/internal/injector/testdata/injector/import-shadowing/config.yml new file mode 100644 index 00000000..7e4c6b20 --- /dev/null +++ b/internal/injector/testdata/injector/import-shadowing/config.yml @@ -0,0 +1,47 @@ +%YAML 1.1 +--- +aspects: + - id: Register + join-point: + function-call: database/sql.Register + advice: + - wrap-expression: + imports: + sqltrace: gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql + sql: database/sql + driver: database/sql/driver + template: |- + func(driverName string, driver driver.Driver) { + sql.Register(driverName, driver) + sqltrace.Register(driverName, driver) + }({{ index .AST.Args 0 }}, {{ index .AST.Args 1 }}) + +syntheticReferences: + database/sql/driver: true # shadowed import + gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql: true + +code: |- + package test + + import ( + "database/sql" + "database/sql/driver" + ) + + var conn driver.Connector + + func main() { + var driver string // shadowing import + sql.Register("foo", nil) + + db1, err := sql.Open("foo", "bar") + if err != nil { + panic(err) + } + defer db1.Close() + + println(driver) + + db2 := sql.OpenDB(conn) + defer db2.Close() + } diff --git a/internal/injector/testdata/injector/import-shadowing/expected.diff b/internal/injector/testdata/injector/import-shadowing/expected.diff new file mode 100644 index 00000000..14cc8cc3 --- /dev/null +++ b/internal/injector/testdata/injector/import-shadowing/expected.diff @@ -0,0 +1,31 @@ +--- input.go ++++ output.go +@@ -1,15 +1,25 @@ ++//line input.go:1:1 + package test + + import ( + "database/sql" +- "database/sql/driver" ++//line :1 ++ __orchestrion_driver "database/sql/driver" ++ __orchestrion_sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" + ) + +-var conn driver.Connector ++//line input.go:8 ++var conn __orchestrion_driver.Connector + + func main() { + var driver string // shadowing import +- sql.Register("foo", nil) ++//line :1 ++ func(driverName string, driver __orchestrion_driver.Driver) { ++ sql.Register(driverName, driver) ++ __orchestrion_sqltrace.Register(driverName, driver) ++ }( ++//line input.go:12 ++ "foo", nil) + + db1, err := sql.Open("foo", "bar") + if err != nil { diff --git a/internal/injector/typed/refmap.go b/internal/injector/typed/refmap.go index 99067388..62180795 100644 --- a/internal/injector/typed/refmap.go +++ b/internal/injector/typed/refmap.go @@ -7,7 +7,9 @@ package typed import ( "fmt" + "go/ast" "go/token" + "go/types" "slices" "strings" @@ -23,6 +25,8 @@ type ( ReferenceMap struct { refs map[string]ReferenceKind aliases map[string]string + nodeMap map[dst.Node]ast.Node + scopes map[ast.Node]*types.Scope } ) @@ -33,44 +37,100 @@ const ( RelocationTarget ReferenceKind = false ) -// AddImport determines whether a new import declaration needs to be added to make the provided path -// available within the specified file. Returns true if that is the case. False if the import path -// is already available within the file. -func (r *ReferenceMap) AddImport(file *dst.File, path string, alias string) bool { - if hasImport(file, path) { +func NewReferenceMap(nodeMap map[dst.Node]ast.Node, scopes map[ast.Node]*types.Scope) ReferenceMap { + return ReferenceMap{nodeMap: nodeMap, scopes: scopes} +} + +// AddImport takes a package import path and the name in file and the result of a recursive parent lookup. +// It first determines if the import is already present +// and if it has not been shadowed by a local declaration. If both conditions are met, the import is added to the +// reference map and the function returns true. Otherwise, it returns false. +func (r *ReferenceMap) AddImport(file *dst.File, nodes []dst.Node, path string, localName string) bool { + if len(nodes) == 0 { + panic("nodeChain must not be empty") + } + + // If the import is already present, has a meaningful alias or no alias, + // and is accessible from the current scope, we don't need to do anything. + prevLocalName, ok := hasImport(file, path) + if ok && prevLocalName != "." && prevLocalName != "_" && r.isImportInScope(nodes, path, localName) { return false } // Register in this ReferenceMap r.add(path, ImportStatement) - if alias != "_" { + if localName != "_" { // We don't register blank aliases, as this is the default behavior anyway... if r.aliases == nil { r.aliases = make(map[string]string) } - r.aliases[path] = fmt.Sprintf("__orchestrion_%s", alias) + r.aliases[path] = fmt.Sprintf("__orchestrion_%s", localName) } return true } -func hasImport(file *dst.File, path string) bool { +// isImportInScope checks if the provided name is an import in the scope of the provided node +func (r *ReferenceMap) isImportInScope(nodes []dst.Node, path string, name string) bool { + if len(nodes) == 0 { + panic("nodes must not be empty") + } + + var ( + scope *types.Scope + pos = r.nodeMap[nodes[0]].Pos() + ) + for i := 0; i < len(nodes) && scope == nil; i++ { + node := nodes[i] + if funcDecl, ok := node.(*dst.FuncDecl); ok { + // Somehow scopes are not attached to FuncDecl nodes, so we need to look at the type ¯\_(シ)_/¯ + node = funcDecl.Type + } + + astNode, ok := r.nodeMap[node] + if !ok { + continue + } + + scope = r.scopes[astNode] + } + + if scope == nil { + panic(fmt.Errorf("unable to find scope for node %T in parent chain", nodes[0])) + } + + _, obj := scope.LookupParent(name, pos) + if obj != nil { + if pkg, isImport := obj.(*types.PkgName); isImport { + return pkg.Imported().Path() == path + } + } + + return false +} + +// hasImport checks if the provided file already imports the provided path and its local name. +func hasImport(file *dst.File, path string) (string, bool) { for _, spec := range file.Imports { specPath, err := basiclit.String(spec.Path) if err != nil { continue } if specPath == path { - return true + name := "" + if spec.Name != nil { + name = spec.Name.Name + } + return name, true } } - return false + return "", false } // AddLink registers the provided path as a relocation target resolution source. If this path is // already registered as an import, this method does nothing and returns false. func (r *ReferenceMap) AddLink(file *dst.File, path string) bool { - if hasImport(file, path) { + if _, ok := hasImport(file, path); ok { return false }