diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go index ff3987473770..64841b38a035 100644 --- a/rlp/rlpgen/gen.go +++ b/rlp/rlpgen/gen.go @@ -24,6 +24,7 @@ import ( "sort" "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" + "golang.org/x/tools/go/packages" ) // buildContext keeps the data needed for make*Op. @@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { // file and assigns unique names of temporary variables. type genContext struct { inPackage *types.Package - imports map[string]struct{} + imports map[string]genImportPackage tempCounter int } +type genImportPackage struct { + alias string + pkg *types.Package +} + func newGenContext(inPackage *types.Package) *genContext { return &genContext{ - inPackage: inPackage, - imports: make(map[string]struct{}), + inPackage: inPackage, + imports: make(map[string]genImportPackage), + tempCounter: 0, } } @@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() { ctx.tempCounter = 0 } -func (ctx *genContext) addImport(path string) { - if path == ctx.inPackage.Path() { - return // avoid importing the package that we're generating in. +func (ctx *genContext) addImportPath(path string) { + pkg, err := ctx.loadPackage(path) + if err != nil { + panic(fmt.Sprintf("can't load package %q: %v", path, err)) } - // TODO: renaming? - ctx.imports[path] = struct{}{} + ctx.addImport(pkg) } -// importsList returns all packages that need to be imported. -func (ctx *genContext) importsList() []string { - imp := make([]string, 0, len(ctx.imports)) - for k := range ctx.imports { - imp = append(imp, k) +func (ctx *genContext) addImport(pkg *types.Package) string { + if pkg.Path() == ctx.inPackage.Path() { + return "" // avoid importing the package that we're generating in } - sort.Strings(imp) - return imp + if p, exists := ctx.imports[pkg.Path()]; exists { + return p.alias + } + var ( + baseName = pkg.Name() + alias = baseName + counter = 1 + ) + // If the base name conflicts with an existing import, add a numeric suffix. + for ctx.hasAlias(alias) { + alias = fmt.Sprintf("%s%d", baseName, counter) + counter++ + } + ctx.imports[pkg.Path()] = genImportPackage{alias, pkg} + return alias +} + +// hasAlias checks if an alias is already in use +func (ctx *genContext) hasAlias(alias string) bool { + for _, p := range ctx.imports { + if p.alias == alias { + return true + } + } + return false } -// qualify is the types.Qualifier used for printing types. +// loadPackage attempts to load package information +func (ctx *genContext) loadPackage(path string) (*types.Package, error) { + cfg := &packages.Config{Mode: packages.NeedName} + pkgs, err := packages.Load(cfg, path) + if err != nil { + return nil, err + } + if len(pkgs) == 0 { + return nil, fmt.Errorf("no package found for path %s", path) + } + return types.NewPackage(path, pkgs[0].Name), nil +} + +// qualify is the types.Qualifier used for printing types func (ctx *genContext) qualify(pkg *types.Package) string { if pkg.Path() == ctx.inPackage.Path() { return "" } - ctx.addImport(pkg.Path()) - // TODO: renaming? - return pkg.Name() + return ctx.addImport(pkg) +} + +// importsList returns all packages that need to be imported +func (ctx *genContext) importsList() []string { + imp := make([]string, 0, len(ctx.imports)) + for path, p := range ctx.imports { + if p.alias == p.pkg.Name() { + imp = append(imp, fmt.Sprintf("%q", path)) + } else { + imp = append(imp, fmt.Sprintf("%s %q", p.alias, path)) + } + } + sort.Strings(imp) + return imp } type op interface { @@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string { } func (op uint256Op) genDecode(ctx *genContext) (string, string) { - ctx.addImport("github.com/holiman/uint256") + ctx.addImportPath("github.com/holiman/uint256") var b bytes.Buffer resultV := ctx.temp() @@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru // generateDecoder generates the DecodeRLP method on 'typ'. func generateDecoder(ctx *genContext, typ string, op op) []byte { ctx.resetTemp() - ctx.addImport(pathOfPackageRLP) + ctx.addImportPath(pathOfPackageRLP) result, code := op.genDecode(ctx) var b bytes.Buffer @@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte { // generateEncoder generates the EncodeRLP method on 'typ'. func generateEncoder(ctx *genContext, typ string, op op) []byte { ctx.resetTemp() - ctx.addImport("io") - ctx.addImport(pathOfPackageRLP) + ctx.addImportPath("io") + ctx.addImportPath(pathOfPackageRLP) var b bytes.Buffer fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) @@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b var b bytes.Buffer fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) for _, imp := range ctx.importsList() { - fmt.Fprintf(&b, "import %q\n", imp) + fmt.Fprintf(&b, "import %s\n", imp) } if encoder { fmt.Fprintln(&b) diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go index b4fabb3dc633..4bfb1b9d255f 100644 --- a/rlp/rlpgen/gen_test.go +++ b/rlp/rlpgen/gen_test.go @@ -47,7 +47,7 @@ func init() { } } -var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"} func TestOutput(t *testing.T) { for _, test := range tests { diff --git a/rlp/rlpgen/testdata/pkgclash.in.txt b/rlp/rlpgen/testdata/pkgclash.in.txt new file mode 100644 index 000000000000..1d407881ce41 --- /dev/null +++ b/rlp/rlpgen/testdata/pkgclash.in.txt @@ -0,0 +1,13 @@ +// -*- mode: go -*- + +package test + +import ( + eth1 "github.com/ethereum/go-ethereum/eth" + eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth" +) + +type Test struct { + A eth1.MinerAPI + B eth2.GetReceiptsPacket +} diff --git a/rlp/rlpgen/testdata/pkgclash.out.txt b/rlp/rlpgen/testdata/pkgclash.out.txt new file mode 100644 index 000000000000..d119639b9978 --- /dev/null +++ b/rlp/rlpgen/testdata/pkgclash.out.txt @@ -0,0 +1,82 @@ +package test + +import "github.com/ethereum/go-ethereum/common" +import "github.com/ethereum/go-ethereum/eth" +import "github.com/ethereum/go-ethereum/rlp" +import "io" +import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + _tmp1 := w.List() + w.ListEnd(_tmp1) + _tmp2 := w.List() + w.WriteUint64(obj.B.RequestId) + _tmp3 := w.List() + for _, _tmp4 := range obj.B.GetReceiptsRequest { + w.WriteBytes(_tmp4[:]) + } + w.ListEnd(_tmp3) + w.ListEnd(_tmp2) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // A: + var _tmp1 eth.MinerAPI + { + if _, err := dec.List(); err != nil { + return err + } + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.A = _tmp1 + // B: + var _tmp2 eth1.GetReceiptsPacket + { + if _, err := dec.List(); err != nil { + return err + } + // RequestId: + _tmp3, err := dec.Uint64() + if err != nil { + return err + } + _tmp2.RequestId = _tmp3 + // GetReceiptsRequest: + var _tmp4 []common.Hash + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + var _tmp5 common.Hash + if err := dec.ReadBytes(_tmp5[:]); err != nil { + return err + } + _tmp4 = append(_tmp4, _tmp5) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp2.GetReceiptsRequest = _tmp4 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.B = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +}