Skip to content

Commit

Permalink
refactor: use custom loader instead of go/packages
Browse files Browse the repository at this point in the history
Having a custom loader lets us reuse parsing and typechecking
results from snapshot to snapshot, which speeds up large scripts
considerably. Before, we were writing out the files between commands
to work around overlay bugs, but that in turn was making the go
command recompile things after every command. Even without
the recompile, not invoking the go command at all after every
script command will be a win.

The time for "git generate" in CL XXX drops from 206s to 101s
from not needing to do a snap.Write after each command, and
it drops further to 35s from the cached loading. There's more to
be improved -- too much is being loaded now.

The custom loader also makes it easier for us to introduce new
packages during rewrites, because we have a full copy of the
package graph that can be updated.

The custom loader also lets us more easily fix a few bugs where
renaming was not finding all identifiers, especially in test packages.
  • Loading branch information
rsc committed Nov 21, 2020
1 parent 56468fb commit 8880f26
Show file tree
Hide file tree
Showing 18 changed files with 1,960 additions and 1,416 deletions.
16 changes: 8 additions & 8 deletions add.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"go/ast"
"go/token"

"golang.org/x/tools/go/packages"
"rsc.io/rf/refactor"
)

func cmdAdd(snap *refactor.Snapshot, args string) (more []string, exp bool) {
func cmdAdd(snap *refactor.Snapshot, args string) {
item, expr, text := snap.LookupNext(args)
if expr == "" {
snap.ErrorAt(token.NoPos, "usage: add address text...\n")
Expand All @@ -30,6 +29,9 @@ func cmdAdd(snap *refactor.Snapshot, args string) (more []string, exp bool) {

case refactor.ItemConst, refactor.ItemFunc, refactor.ItemType, refactor.ItemVar, refactor.ItemField:
stack := snap.SyntaxAt(item.Obj.Pos())
if len(stack) == 0 {
panic("LOST " + item.Name)
}
after := stack[1]
switch after.(type) {
case *ast.ValueSpec, *ast.TypeSpec:
Expand All @@ -41,28 +43,26 @@ func cmdAdd(snap *refactor.Snapshot, args string) (more []string, exp bool) {
_, pos = nodeRange(snap, after)

case refactor.ItemFile:
srcPkg := snap.Targets()[0] // TODO
srcFile := findFile(snap, srcPkg, item.Name)
_, srcFile := snap.FileByName(item.Name)
_, pos = snap.FileRange(srcFile.Package)

case refactor.ItemDir:
var dstPkg *packages.Package
var dstPkg *refactor.Package
for _, pkg := range snap.Packages() {
if pkg.PkgPath == item.Name {
dstPkg = pkg
break
}
}
if dstPkg == nil {
return []string{item.Name}, false
return
}
_, pos = snap.FileRange(dstPkg.Syntax[0].Package)
_, pos = snap.FileRange(dstPkg.Files[0].Syntax.Pos())

case refactor.ItemPos:
pos = item.End
}

// TODO: Is final \n a good idea?
snap.InsertAt(pos, text+"\n")
return nil, false
}
40 changes: 18 additions & 22 deletions ex.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"strconv"
"strings"

"golang.org/x/tools/go/packages"
"rsc.io/rf/edit"
"rsc.io/rf/refactor"
)
Expand All @@ -24,15 +23,12 @@ type example struct {
new ast.Node
}

func cmdEx(snap *refactor.Snapshot, text string) (more []string, exp bool) {
more, code, err := parseEx(snap, text)
func cmdEx(snap *refactor.Snapshot, text string) {
code, err := parseEx(snap, text)
if err != nil {
snap.ErrorAt(token.NoPos, "ex: %v", err)
return
}
if len(more) > 0 {
return
}

if _, err := checkEx(snap, code); err != nil {
snap.ErrorAt(token.NoPos, "ex: %v", err)
Expand All @@ -47,11 +43,11 @@ func cutGo(text, sep string) (before, after string, ok bool, err error) {
return before, after, ok, nil
}

func parseEx(snap *refactor.Snapshot, text string) (more []string, code string, err error) {
func parseEx(snap *refactor.Snapshot, text string) (code string, err error) {
fset := token.NewFileSet()
var buf bytes.Buffer
if len(snap.Targets()) == 1 {
fmt.Fprintf(&buf, "package %s\n", snap.Targets()[0].Types.Name())
if true { // TODO: single vs multiple targets
fmt.Fprintf(&buf, "package %s\n", snap.Target().Types.Name())
} else {
fmt.Fprintf(&buf, "package ex\n")
}
Expand All @@ -60,7 +56,7 @@ func parseEx(snap *refactor.Snapshot, text string) (more []string, code string,
for text != "" {
stmt, rest, _, err := cutGo(text, ";")
if err != nil {
return nil, "", err
return "", err
}
text = rest
stmt = strings.TrimSpace(stmt)
Expand All @@ -69,15 +65,15 @@ func parseEx(snap *refactor.Snapshot, text string) (more []string, code string,
}
switch kw := strings.Fields(stmt)[0]; kw {
case "package", "type", "func", "const":
return nil, "", fmt.Errorf("%s declaration not allowed", kw)
return "", fmt.Errorf("%s declaration not allowed", kw)

case "defer", "for", "go", "if", "return", "select", "switch":
return nil, "", fmt.Errorf("%s statement not allowed", kw)
return "", fmt.Errorf("%s statement not allowed", kw)

case "import":
file, err := parser.ParseFile(fset, "ex.go", "package p;"+stmt, 0)
if err != nil {
return nil, "", fmt.Errorf("parsing %s: %v", stmt, err)
return "", fmt.Errorf("parsing %s: %v", stmt, err)
}
imp := file.Imports[0]
pkg := importPath(imp)
Expand All @@ -88,16 +84,16 @@ func parseEx(snap *refactor.Snapshot, text string) (more []string, code string,
}
}
if !have {
more = append(more, pkg)
return "", fmt.Errorf("import %q not available", pkg)
}
if !importOK {
return nil, "", fmt.Errorf("parsing %s: import too late", stmt)
return "", fmt.Errorf("parsing %s: import too late", stmt)
}
fmt.Fprintf(&buf, "%s\n", stmt)

case "var":
if _, err := parser.ParseExpr("func() {" + stmt + "}"); err != nil {
return nil, "", fmt.Errorf("parsing %s: %v", stmt, err)
return "", fmt.Errorf("parsing %s: %v", stmt, err)
}
if importOK {
fmt.Fprintf(&buf, "func _() {\n")
Expand All @@ -111,7 +107,7 @@ func parseEx(snap *refactor.Snapshot, text string) (more []string, code string,
// because we already processed it once.
before, after, ok, _ := cutGo(stmt, "->")
if !ok {
return nil, "", fmt.Errorf("parsing: %s: missing -> in rewrite", stmt)
return "", fmt.Errorf("parsing: %s: missing -> in rewrite", stmt)
}
// TODO: parse stmt / parse expr
if importOK {
Expand All @@ -122,10 +118,10 @@ func parseEx(snap *refactor.Snapshot, text string) (more []string, code string,
}
}
if importOK {
return nil, "", fmt.Errorf("no example rewrites")
return "", fmt.Errorf("no example rewrites")
}
fmt.Fprintf(&buf, "}\n")
return more, buf.String(), nil
return buf.String(), nil
}

func checkEx(snap *refactor.Snapshot, code string) ([]example, error) {
Expand Down Expand Up @@ -157,8 +153,8 @@ func checkEx(snap *refactor.Snapshot, code string) ([]example, error) {

var info *types.Info
var typesPkg *types.Package
if len(snap.Targets()) == 1 {
p := snap.Targets()[0]
if true { // TODO single vs double
p := snap.Target()
info = p.TypesInfo
typesPkg = p.Types
} else {
Expand Down Expand Up @@ -233,7 +229,7 @@ func applyEx(snap *refactor.Snapshot, code string, codePos token.Pos, typesPkg *

var avoid map[ast.Node]bool

snap.ForEachTargetFile(func(target *packages.Package, file *ast.File) {
snap.ForEachTargetFile(func(target *refactor.Package, file *ast.File) {
refactor.Walk(file, func(stack []ast.Node) {
if m.match(pattern, stack[0]) {
// Do not apply substitution in its own definition.
Expand Down
9 changes: 3 additions & 6 deletions inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import (
"go/types"
"strings"

"golang.org/x/tools/go/packages"
"rsc.io/rf/refactor"
)

func cmdInline(snap *refactor.Snapshot, args string) (more []string, exp bool) {
func cmdInline(snap *refactor.Snapshot, args string) {
args = strings.TrimLeft(args, " \t")
var rm map[types.Object]bool
if flag, rest, _ := cutAny(args, " \t"); flag == "-rm" {
Expand Down Expand Up @@ -67,13 +66,11 @@ func cmdInline(snap *refactor.Snapshot, args string) (more []string, exp bool) {
removeDecls(snap, rm)
}

// TODO: If the names are exported, should we inline elsewhere too?
// Probably. Certainly if they are being deleted.
return nil, false
// TODO: Should we inline in other packages?
}

func inlineValues(snap *refactor.Snapshot, fix map[types.Object]ast.Expr) {
snap.ForEachFile(func(pkg *packages.Package, file *ast.File) {
snap.ForEachFile(func(pkg *refactor.Package, file *ast.File) {
refactor.Walk(file, func(stack []ast.Node) {
id, ok := stack[0].(*ast.Ident)
if !ok {
Expand Down
9 changes: 2 additions & 7 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ import (
"go/token"
"go/types"

"golang.org/x/tools/go/packages"
"rsc.io/rf/refactor"
)

func cmdKey(snap *refactor.Snapshot, args string) (more []string, exp bool) {
func cmdKey(snap *refactor.Snapshot, args string) {
items, _ := snap.LookupAll(args)
if len(items) == 0 {
snap.ErrorAt(token.NoPos, "usage: key StructType...")
Expand Down Expand Up @@ -41,10 +40,6 @@ func cmdKey(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}

keyLiterals(snap, fixing)

// Any cross-package literal references should already be keyed,
// so no need to consider importers.
return nil, false
}

func keyLiterals(snap *refactor.Snapshot, list []types.Type) {
Expand All @@ -53,7 +48,7 @@ func keyLiterals(snap *refactor.Snapshot, list []types.Type) {
fixing[t] = true
}

snap.ForEachFile(func(pkg *packages.Package, file *ast.File) {
snap.ForEachFile(func(pkg *refactor.Package, file *ast.File) {
refactor.Walk(file, func(stack []ast.Node) {
lit, ok := stack[0].(*ast.CompositeLit)
if !ok || len(lit.Elts) == 0 || lit.Incomplete {
Expand Down
40 changes: 24 additions & 16 deletions mv.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"reflect"
"strings"

"golang.org/x/tools/go/packages"
"rsc.io/rf/refactor"
)

Expand Down Expand Up @@ -45,7 +44,7 @@ func inScope(name string, obj types.Object) posChecker {
}
}

func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
func cmdMv(snap *refactor.Snapshot, args string) {
items, _ := snap.LookupAll(args)
if len(items) < 2 {
snap.ErrorAt(token.NoPos, "usage: mv old... new")
Expand All @@ -64,7 +63,7 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {

srcs, dst := items[:len(items)-1], items[len(items)-1]
if dst.Kind == refactor.ItemDir || dst.Kind == refactor.ItemFile {
var dstPkg *packages.Package
var dstPkg *refactor.Package
if dst.Kind == refactor.ItemDir {
for _, pkg := range snap.Packages() {
if pkg.PkgPath == dst.Name {
Expand All @@ -73,11 +72,13 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}
}
if dstPkg == nil {
return []string{dst.Name}, false
// TODO: Load on demand.
snap.ErrorAt(token.NoPos, "unknown package %s", dst.Name)
return
}
} else {
if !strings.Contains(dst.Name, "/") {
dstPkg = snap.Targets()[0] // TODO
dstPkg = snap.Target() // TODO
} else {
pkgPath := path.Dir(dst.Name)
for _, pkg := range snap.Packages() {
Expand All @@ -87,7 +88,9 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}
}
if dstPkg == nil {
return []string{dst.Name}, false
// TODO: Load on demand.
snap.ErrorAt(token.NoPos, "unknown package %s", dst.Name)
return
}
dst.Name = path.Base(dst.Name)
}
Expand All @@ -110,7 +113,7 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
return
}

exp = mvCode(snap, srcs, dst, dstPkg)
mvCode(snap, srcs, dst, dstPkg)
return
}

Expand Down Expand Up @@ -155,7 +158,6 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}
rewriteDefn(snap, old, newName)
rewriteUses(snap, old, newName, notInScope(newName))
exp = token.IsExported(old.Name)
return
}

Expand All @@ -167,8 +169,6 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}
rewriteDefn(snap, old, newName)
rewriteUses(snap, old, newName, nil)
_, last, _ := cutLast(old.Name, ".")
exp = token.IsExported(last)
return
}

Expand Down Expand Up @@ -200,7 +200,6 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
}
removeDecl(snap, old)
rewriteUses(snap, old, newPath, inScope(newTop.Name, newTop.Obj))
exp = token.IsExported(old.Name)
return
}
}
Expand Down Expand Up @@ -228,8 +227,6 @@ func cmdMv(snap *refactor.Snapshot, args string) (more []string, exp bool) {
return
}
methodToFunc(snap, old.Obj.(*types.Func), newName)
_, last, _ := cutLast(old.Name, ".")
exp = token.IsExported(last)
return
}

Expand Down Expand Up @@ -265,7 +262,7 @@ func rewriteDefn(snap *refactor.Snapshot, old *refactor.Item, new string) {
}

func rewriteUses(snap *refactor.Snapshot, old *refactor.Item, new string, checkPos posChecker) {
snap.ForEachFile(func(pkg *packages.Package, file *ast.File) {
fix := func(pkg *refactor.Package, file *ast.File) {
refactor.Walk(file, func(stack []ast.Node) {
id, ok := stack[0].(*ast.Ident)
if !ok || pkg.TypesInfo.Uses[id] != old.Obj {
Expand All @@ -282,7 +279,18 @@ func rewriteUses(snap *refactor.Snapshot, old *refactor.Item, new string, checkP
}
snap.ReplaceNode(id, new)
})
})
}
// TODO: This should be something like
// snap.ForReverseDepsFiles
// and it should load the reverse deps on demand.
if !token.IsExported(old.Outermost().Name) {
pkg, _ := snap.FileAt(old.Obj.Pos())
for _, file := range pkg.Files {
fix(pkg, file.Syntax)
}
return
}
snap.ForEachFile(fix)
}

func StackTypes(list []ast.Node) string {
Expand Down Expand Up @@ -395,7 +403,7 @@ func methodToFunc(snap *refactor.Snapshot, method *types.Func, name string) {
recvType := sig.Recv().Type()
_, recvPtr := recvType.(*types.Pointer)

snap.ForEachFile(func(pkg *packages.Package, file *ast.File) {
snap.ForEachFile(func(pkg *refactor.Package, file *ast.File) {
refactor.Walk(file, func(stack []ast.Node) {
id, ok := stack[0].(*ast.Ident)
if !ok || pkg.TypesInfo.Uses[id] != method {
Expand Down
Loading

0 comments on commit 8880f26

Please sign in to comment.