diff --git a/src/cmd/cover/cover.go b/src/cmd/cover/cover.go index 5bea3b11aa5978..d74a0f18801b90 100644 --- a/src/cmd/cover/cover.go +++ b/src/cmd/cover/cover.go @@ -10,17 +10,15 @@ import ( "fmt" "go/ast" "go/parser" - "go/printer" "go/token" "io" "io/ioutil" "log" "os" - "path/filepath" "sort" "strconv" - "strings" + "cmd/internal/edit" "cmd/internal/objabi" ) @@ -61,7 +59,7 @@ var ( var profile string // The profile to read; the value of -html or -func -var counterStmt func(*File, ast.Expr) ast.Stmt +var counterStmt func(*File, string) string const ( atomicPackagePath = "sync/atomic" @@ -154,11 +152,48 @@ type Block struct { // File is a wrapper for the state of a file used in the parser. // The basic parse tree walker is a method of this type. type File struct { - fset *token.FileSet - name string // Name of file. - astFile *ast.File - blocks []Block - atomicPkg string // Package name for "sync/atomic" in this file. + fset *token.FileSet + name string // Name of file. + astFile *ast.File + blocks []Block + content []byte + edit *edit.Buffer +} + +// findText finds text in the original source, starting at pos. +// It correctly skips over comments and assumes it need not +// handle quoted strings. +// It returns a byte offset within f.src. +func (f *File) findText(pos token.Pos, text string) int { + b := []byte(text) + start := f.offset(pos) + i := start + s := f.content + for i < len(s) { + if bytes.HasPrefix(s[i:], b) { + return i + } + if i+2 <= len(s) && s[i] == '/' && s[i+1] == '/' { + for i < len(s) && s[i] != '\n' { + i++ + } + continue + } + if i+2 <= len(s) && s[i] == '/' && s[i+1] == '*' { + for i += 2; ; i++ { + if i+2 > len(s) { + return 0 + } + if s[i] == '*' && s[i+1] == '/' { + i += 2 + break + } + } + continue + } + i++ + } + return -1 } // Visit implements the ast.Visitor interface. @@ -171,18 +206,18 @@ func (f *File) Visit(node ast.Node) ast.Visitor { case *ast.CaseClause: // switch for _, n := range n.List { clause := n.(*ast.CaseClause) - clause.Body = f.addCounters(clause.Colon+1, clause.End(), clause.Body, false) + f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) } return f case *ast.CommClause: // select for _, n := range n.List { clause := n.(*ast.CommClause) - clause.Body = f.addCounters(clause.Colon+1, clause.End(), clause.Body, false) + f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) } return f } } - n.List = f.addCounters(n.Lbrace, n.Rbrace+1, n.List, true) // +1 to step past closing brace. + f.addCounters(n.Lbrace, n.Lbrace+1, n.Rbrace+1, n.List, true) // +1 to step past closing brace. case *ast.IfStmt: if n.Init != nil { ast.Walk(f, n.Init) @@ -203,6 +238,13 @@ func (f *File) Visit(node ast.Node) ast.Visitor { // if y { // } // } + f.edit.Insert(f.offset(n.Body.End()), "else{") + elseOffset := f.findText(n.Body.End(), "else") + if elseOffset < 0 { + panic("lost else") + } + f.edit.Delete(elseOffset, elseOffset+4) + f.edit.Insert(f.offset(n.Else.End()), "}") switch stmt := n.Else.(type) { case *ast.IfStmt: block := &ast.BlockStmt{ @@ -247,91 +289,6 @@ func (f *File) Visit(node ast.Node) ast.Visitor { return f } -// fixDirectives identifies //go: comments (known as directives or pragmas) and -// attaches them to the documentation group of the following top-level -// definitions. All other comments are dropped since non-documentation comments -// tend to get printed in the wrong place after the AST is modified. -// -// fixDirectives returns a list of unhandled directives. These comments could -// not be attached to a top-level declaration and should be be printed at the -// end of the file. -func (f *File) fixDirectives() []*ast.Comment { - // Scan comments in the file and collect directives. Detach all comments. - var directives []*ast.Comment - prev := token.NoPos // smaller than any valid token.Pos - for _, cg := range f.astFile.Comments { - for _, c := range cg.List { - // Skip directives that will be included by initialComments, i.e., those - // before the package declaration but not in the file doc comment group. - if f.isDirective(c) && (c.Pos() >= f.astFile.Package || cg == f.astFile.Doc) { - // Assume (but verify) that comments are sorted by position. - pos := c.Pos() - if !pos.IsValid() { - log.Fatalf("compiler directive has no position: %q", c.Text) - } else if pos < prev { - log.Fatalf("compiler directives are out of order. %s was before %s.", - f.fset.Position(prev), f.fset.Position(pos)) - } - prev = pos - - directives = append(directives, c) - } - } - cg.List = nil - } - f.astFile.Comments = nil // force printer to use node comments - if len(directives) == 0 { - // Common case: no directives to attach. - return nil - } - - // Iterate over top-level declarations and attach preceding directives. - di := 0 - prev = token.NoPos - for _, decl := range f.astFile.Decls { - // Assume (but verify) that declarations are sorted by position. - pos := decl.Pos() - if !pos.IsValid() { - // Synthetic decl. Don't add directives. - continue - } - if pos < prev { - log.Fatalf("declarations are out of order. %s was before %s.", - f.fset.Position(prev), f.fset.Position(pos)) - } - prev = pos - - var doc **ast.CommentGroup - switch d := decl.(type) { - case *ast.FuncDecl: - doc = &d.Doc - case *ast.GenDecl: - // Limitation: for grouped declarations, we attach directives to the decl, - // not individual specs. Directives must start in the first column, so - // they are lost when the group is indented. - doc = &d.Doc - default: - // *ast.BadDecls is the only other type we might see, but - // we don't need to handle it here. - continue - } - - for di < len(directives) && directives[di].Pos() < pos { - c := directives[di] - if *doc == nil { - *doc = new(ast.CommentGroup) - } - c.Slash = pos - 1 // must be strictly less than pos - (*doc).List = append((*doc).List, c) - di++ - } - } - - // Return trailing directives. These cannot apply to a specific declaration - // and may be printed at the end of the file. - return directives[di:] -} - // unquote returns the unquoted string. func unquote(s string) string { t, err := strconv.Unquote(s) @@ -341,91 +298,8 @@ func unquote(s string) string { return t } -// addImport adds an import for the specified path, if one does not already exist, and returns -// the local package name. -func (f *File) addImport(path string) string { - // Does the package already import it? - for _, s := range f.astFile.Imports { - if unquote(s.Path.Value) == path { - if s.Name != nil { - return s.Name.Name - } - return filepath.Base(path) - } - } - newImport := &ast.ImportSpec{ - Name: ast.NewIdent(atomicPackageName), - Path: &ast.BasicLit{ - Kind: token.STRING, - Value: fmt.Sprintf("%q", path), - }, - } - impDecl := &ast.GenDecl{ - Tok: token.IMPORT, - Specs: []ast.Spec{ - newImport, - }, - } - // Make the new import the first Decl in the file. - astFile := f.astFile - astFile.Decls = append(astFile.Decls, nil) - copy(astFile.Decls[1:], astFile.Decls[0:]) - astFile.Decls[0] = impDecl - astFile.Imports = append(astFile.Imports, newImport) - - // Now refer to the package, just in case it ends up unused. - // That is, append to the end of the file the declaration - // var _ = _cover_atomic_.AddUint32 - reference := &ast.GenDecl{ - Tok: token.VAR, - Specs: []ast.Spec{ - &ast.ValueSpec{ - Names: []*ast.Ident{ - ast.NewIdent("_"), - }, - Values: []ast.Expr{ - &ast.SelectorExpr{ - X: ast.NewIdent(atomicPackageName), - Sel: ast.NewIdent("AddUint32"), - }, - }, - }, - }, - } - astFile.Decls = append(astFile.Decls, reference) - return atomicPackageName -} - var slashslash = []byte("//") -// initialComments returns the prefix of content containing only -// whitespace and line comments. Any +build directives must appear -// within this region. This approach is more reliable than using -// go/printer to print a modified AST containing comments. -// -func initialComments(content []byte) []byte { - // Derived from go/build.Context.shouldBuild. - end := 0 - p := content - for len(p) > 0 { - line := p - if i := bytes.IndexByte(line, '\n'); i >= 0 { - line, p = line[:i], p[i+1:] - } else { - p = p[len(p):] - } - line = bytes.TrimSpace(line) - if len(line) == 0 { // Blank line. - end = len(content) - len(p) - continue - } - if !bytes.HasPrefix(line, slashslash) { // Not comment line. - break - } - } - return content[:end] -} - func annotate(name string) { fset := token.NewFileSet() content, err := ioutil.ReadFile(name) @@ -440,14 +314,23 @@ func annotate(name string) { file := &File{ fset: fset, name: name, + content: content, + edit: edit.NewBuffer(content), astFile: parsedFile, } if *mode == "atomic" { - file.atomicPkg = file.addImport(atomicPackagePath) + // Add import of sync/atomic immediately after package clause. + // We do this even if there is an existing import, because the + // existing import may be shadowed at any given place we want + // to refer to it, and our name (_cover_atomic_) is less likely to + // be shadowed. + file.edit.Insert(file.offset(file.astFile.Name.End()), + fmt.Sprintf("; import %s %q", atomicPackageName, atomicPackagePath)) } - unhandledDirectives := file.fixDirectives() ast.Walk(file, file.astFile) + newContent := file.edit.Bytes() + fd := os.Stdout if *output != "" { var err error @@ -456,89 +339,32 @@ func annotate(name string) { log.Fatalf("cover: %s", err) } } - fd.Write(initialComments(content)) // Retain '// +build' directives. - file.print(fd) + fd.Write(newContent) + // After printing the source tree, add some declarations for the counters etc. // We could do this by adding to the tree, but it's easier just to print the text. file.addVariables(fd) - - // Print directives that are not processed in fixDirectives. Some - // directives are not attached to anything in the tree hence will not - // be printed by printer. So, we have to explicitly print them here. - for _, c := range unhandledDirectives { - fmt.Fprintln(fd, c.Text) - } -} - -func (f *File) print(w io.Writer) { - printer.Fprint(w, f.fset, f.astFile) -} - -// isDirective reports whether a comment is a compiler directive. -func (f *File) isDirective(c *ast.Comment) bool { - return strings.HasPrefix(c.Text, "//go:") && f.fset.Position(c.Slash).Column == 1 -} - -// intLiteral returns an ast.BasicLit representing the integer value. -func (f *File) intLiteral(i int) *ast.BasicLit { - node := &ast.BasicLit{ - Kind: token.INT, - Value: fmt.Sprint(i), - } - return node -} - -// index returns an ast.BasicLit representing the number of counters present. -func (f *File) index() *ast.BasicLit { - return f.intLiteral(len(f.blocks)) } // setCounterStmt returns the expression: __count[23] = 1. -func setCounterStmt(f *File, counter ast.Expr) ast.Stmt { - return &ast.AssignStmt{ - Lhs: []ast.Expr{counter}, - Tok: token.ASSIGN, - Rhs: []ast.Expr{f.intLiteral(1)}, - } +func setCounterStmt(f *File, counter string) string { + return fmt.Sprintf("%s = 1", counter) } // incCounterStmt returns the expression: __count[23]++. -func incCounterStmt(f *File, counter ast.Expr) ast.Stmt { - return &ast.IncDecStmt{ - X: counter, - Tok: token.INC, - } +func incCounterStmt(f *File, counter string) string { + return fmt.Sprintf("%s++", counter) } // atomicCounterStmt returns the expression: atomic.AddUint32(&__count[23], 1) -func atomicCounterStmt(f *File, counter ast.Expr) ast.Stmt { - return &ast.ExprStmt{ - X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(f.atomicPkg), - Sel: ast.NewIdent("AddUint32"), - }, - Args: []ast.Expr{&ast.UnaryExpr{ - Op: token.AND, - X: counter, - }, - f.intLiteral(1), - }, - }, - } +func atomicCounterStmt(f *File, counter string) string { + return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter) } // newCounter creates a new counter expression of the appropriate form. -func (f *File) newCounter(start, end token.Pos, numStmt int) ast.Stmt { - counter := &ast.IndexExpr{ - X: &ast.SelectorExpr{ - X: ast.NewIdent(*varVar), - Sel: ast.NewIdent("Count"), - }, - Index: f.index(), - } - stmt := counterStmt(f, counter) +func (f *File) newCounter(start, end token.Pos, numStmt int) string { + stmt := counterStmt(f, fmt.Sprintf("%s.Count[%d]", *varVar, len(f.blocks))) f.blocks = append(f.blocks, Block{start, end, numStmt}) return stmt } @@ -555,15 +381,15 @@ func (f *File) newCounter(start, end token.Pos, numStmt int) ast.Stmt { // counters will be added before S1 and before S3. The block containing S2 // will be visited in a separate call. // TODO: Nested simple blocks get unnecessary (but correct) counters -func (f *File) addCounters(pos, blockEnd token.Pos, list []ast.Stmt, extendToClosingBrace bool) []ast.Stmt { +func (f *File) addCounters(pos, insertPos, blockEnd token.Pos, list []ast.Stmt, extendToClosingBrace bool) { // Special case: make sure we add a counter to an empty block. Can't do this below // or we will add a counter to an empty statement list after, say, a return statement. if len(list) == 0 { - return []ast.Stmt{f.newCounter(pos, blockEnd, 0)} + f.edit.Insert(f.offset(insertPos), f.newCounter(insertPos, blockEnd, 0)+";") + return } // We have a block (statement list), but it may have several basic blocks due to the // appearance of statements that affect the flow of control. - var newList []ast.Stmt for { // Find first statement that affects flow of control (break, continue, if, etc.). // It will be the last statement of this basic block. @@ -606,16 +432,15 @@ func (f *File) addCounters(pos, blockEnd token.Pos, list []ast.Stmt, extendToClo end = blockEnd } if pos != end { // Can have no source to cover if e.g. blocks abut. - newList = append(newList, f.newCounter(pos, end, last)) + f.edit.Insert(f.offset(insertPos), f.newCounter(pos, end, last)+";") } - newList = append(newList, list[0:last]...) list = list[last:] if len(list) == 0 { break } pos = list[0].Pos() + insertPos = pos } - return newList } // hasFuncLiteral reports the existence and position of the first func literal @@ -850,4 +675,10 @@ func (f *File) addVariables(w io.Writer) { // Close the struct initialization. fmt.Fprintf(w, "}\n") + + // Emit a reference to the atomic package to avoid + // import and not used error when there's no code in a file. + if *mode == "atomic" { + fmt.Fprintf(w, "var _ = %s.LoadUint32\n", atomicPackageName) + } } diff --git a/src/cmd/cover/cover_test.go b/src/cmd/cover/cover_test.go index 4d8826b96db89a..79ddf4f4652a2a 100644 --- a/src/cmd/cover/cover_test.go +++ b/src/cmd/cover/cover_test.go @@ -6,6 +6,7 @@ package main_test import ( "bytes" + "flag" "fmt" "go/ast" "go/parser" @@ -37,7 +38,7 @@ var ( coverProfile = filepath.Join(testdata, "profile.cov") ) -var debug = false // Keeps the rewritten files around if set. +var debug = flag.Bool("debug", false, "keep rewritten files for debugging") // Run this shell script, but do it in Go so it can be run by "go test". // @@ -63,7 +64,7 @@ func TestCover(t *testing.T) { } // defer removal of test_line.go - if !debug { + if !*debug { defer os.Remove(coverInput) } @@ -79,7 +80,7 @@ func TestCover(t *testing.T) { run(cmd, t) // defer removal of ./testdata/test_cover.go - if !debug { + if !*debug { defer os.Remove(coverOutput) } @@ -100,10 +101,10 @@ func TestCover(t *testing.T) { t.Error("'go:linkname' compiler directive not found") } - // No other comments should be present in generated code. - c := ".*// This comment shouldn't appear in generated go code.*" - if got, err := regexp.MatchString(c, string(file)); err != nil || got { - t.Errorf("non compiler directive comment %q found", c) + // Other comments should be preserved too. + c := ".*// This comment didn't appear in generated go code.*" + if got, err := regexp.MatchString(c, string(file)); err != nil || !got { + t.Errorf("non compiler directive comment %q not found", c) } } diff --git a/src/cmd/cover/testdata/test.go b/src/cmd/cover/testdata/test.go index 5effa2d7e905fa..0b03ef91ab559a 100644 --- a/src/cmd/cover/testdata/test.go +++ b/src/cmd/cover/testdata/test.go @@ -282,7 +282,7 @@ loop: } } -// This comment shouldn't appear in generated go code. +// This comment didn't appear in generated go code. func haha() { // Needed for cover to add counter increment here. _ = 42 diff --git a/src/cmd/internal/edit/edit.go b/src/cmd/internal/edit/edit.go new file mode 100644 index 00000000000000..2d470f4c8a9ec1 --- /dev/null +++ b/src/cmd/internal/edit/edit.go @@ -0,0 +1,93 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package edit implements buffered position-based editing of byte slices. +package edit + +import ( + "fmt" + "sort" +) + +// A Buffer is a queue of edits to apply to a given byte slice. +type Buffer struct { + old []byte + q edits +} + +// An edit records a single text modification: change the bytes in [start,end) to new. +type edit struct { + start int + end int + new string +} + +// An edits is a list of edits that is sortable by start offset, breaking ties by end offset. +type edits []edit + +func (x edits) Len() int { return len(x) } +func (x edits) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x edits) Less(i, j int) bool { + if x[i].start != x[j].start { + return x[i].start < x[j].start + } + return x[i].end < x[j].end +} + +// NewBuffer returns a new buffer to accumulate changes to an initial data slice. +// The returned buffer maintains a reference to the data, so the caller must ensure +// the data is not modified until after the Buffer is done being used. +func NewBuffer(data []byte) *Buffer { + return &Buffer{old: data} +} + +func (b *Buffer) Insert(pos int, new string) { + if pos < 0 || pos > len(b.old) { + panic("invalid edit position") + } + b.q = append(b.q, edit{pos, pos, new}) +} + +func (b *Buffer) Delete(start, end int) { + if end < start || start < 0 || end > len(b.old) { + panic("invalid edit position") + } + b.q = append(b.q, edit{start, end, ""}) +} + +func (b *Buffer) Replace(start, end int, new string) { + if end < start || start < 0 || end > len(b.old) { + panic("invalid edit position") + } + b.q = append(b.q, edit{start, end, new}) +} + +// Bytes returns a new byte slice containing the original data +// with the queued edits applied. +func (b *Buffer) Bytes() []byte { + // Sort edits by starting position and then by ending position. + // Breaking ties by ending position allows insertions at point x + // to be applied before a replacement of the text at [x, y). + sort.Stable(b.q) + + var new []byte + offset := 0 + for i, e := range b.q { + if e.start < offset { + e0 := b.q[i-1] + panic(fmt.Sprintf("overlapping edits: [%d,%d)->%q, [%d,%d)->%q", e0.start, e0.end, e0.new, e.start, e.end, e.new)) + } + new = append(new, b.old[offset:e.start]...) + offset = e.end + new = append(new, e.new...) + } + new = append(new, b.old[offset:]...) + return new +} + +// String returns a string containing the original data +// with the queued edits applied. +func (b *Buffer) String() string { + return string(b.Bytes()) +} diff --git a/src/cmd/internal/edit/edit_test.go b/src/cmd/internal/edit/edit_test.go new file mode 100644 index 00000000000000..0e0c564d987046 --- /dev/null +++ b/src/cmd/internal/edit/edit_test.go @@ -0,0 +1,28 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package edit + +import "testing" + +func TestEdit(t *testing.T) { + b := NewBuffer([]byte("0123456789")) + b.Insert(8, ",7½,") + b.Replace(9, 10, "the-end") + b.Insert(10, "!") + b.Insert(4, "3.14,") + b.Insert(4, "π,") + b.Insert(4, "3.15,") + b.Replace(3, 4, "three,") + want := "012three,3.14,π,3.15,4567,7½,8the-end!" + + s := b.String() + if s != want { + t.Errorf("b.String() = %q, want %q", s, want) + } + sb := b.Bytes() + if string(sb) != want { + t.Errorf("b.Bytes() = %q, want %q", sb, want) + } +}