Skip to content

Commit e9242f8

Browse files
committed
rlp/rlpgen: fix
1 parent e1c23fa commit e9242f8

File tree

1 file changed

+32
-31
lines changed

1 file changed

+32
-31
lines changed

rlp/rlpgen/gen.go

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"go/format"
2323
"go/types"
2424
"sort"
25-
"strings"
2625

2726
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
2827
"golang.org/x/tools/go/packages"
@@ -98,14 +97,19 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
9897
// file and assigns unique names of temporary variables.
9998
type genContext struct {
10099
inPackage *types.Package
101-
imports map[string]string // Changed from map[string]struct{} to map[string]string to store package aliases
100+
imports map[string]genImportPackage
102101
tempCounter int
103102
}
104103

104+
type genImportPackage struct {
105+
alias string
106+
pkg *types.Package
107+
}
108+
105109
func newGenContext(inPackage *types.Package) *genContext {
106110
return &genContext{
107111
inPackage: inPackage,
108-
imports: make(map[string]string),
112+
imports: make(map[string]genImportPackage),
109113
tempCounter: 0,
110114
}
111115
}
@@ -120,42 +124,39 @@ func (ctx *genContext) resetTemp() {
120124
ctx.tempCounter = 0
121125
}
122126

123-
func (ctx *genContext) addImport(path string) string {
124-
if path == ctx.inPackage.Path() {
125-
return "" // avoid importing the package that we're generating in
126-
}
127-
128-
// Check if we already have an alias for this package
129-
if alias, exists := ctx.imports[path]; exists {
130-
return alias
131-
}
132-
133-
// Get the package name and check for conflicts
127+
func (ctx *genContext) addImportPath(path string) {
134128
pkg, err := ctx.loadPackage(path)
135129
if err != nil {
136-
// If we can't load the package, use the last component of the path
137-
parts := strings.Split(path, "/")
138-
pkg = types.NewPackage(path, parts[len(parts)-1])
130+
panic(fmt.Sprintf("can't load package %q: %v", path, err))
139131
}
132+
ctx.addImport(pkg)
133+
}
140134

135+
func (ctx *genContext) addImport(pkg *types.Package) string {
136+
if pkg.Path() == ctx.inPackage.Path() {
137+
return "" // avoid importing the package that we're generating in
138+
}
139+
if p, exists := ctx.imports[pkg.Path()]; exists {
140+
return p.alias
141+
}
141142
baseName := pkg.Name()
142143
alias := baseName
143144
counter := 1
144-
145+
145146
// If the base name conflicts with any existing import, add a numeric suffix
146147
for ctx.hasAlias(alias) {
147148
alias = fmt.Sprintf("%s%d", baseName, counter)
148149
counter++
149150
}
150-
151-
ctx.imports[path] = alias
151+
152+
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
152153
return alias
153154
}
154155

155156
// hasAlias checks if an alias is already in use
156157
func (ctx *genContext) hasAlias(alias string) bool {
157-
for _, existingAlias := range ctx.imports {
158-
if existingAlias == alias {
158+
for _, p := range ctx.imports {
159+
if p.alias == alias {
159160
return true
160161
}
161162
}
@@ -180,19 +181,19 @@ func (ctx *genContext) qualify(pkg *types.Package) string {
180181
if pkg.Path() == ctx.inPackage.Path() {
181182
return ""
182183
}
183-
return ctx.addImport(pkg.Path())
184+
return ctx.addImport(pkg)
184185
}
185186

186187
// importsList returns all packages that need to be imported
187188
func (ctx *genContext) importsList() []string {
188189
imp := make([]string, 0, len(ctx.imports))
189-
for path, alias := range ctx.imports {
190-
if alias == pkg.Name() {
190+
for path, p := range ctx.imports {
191+
if p.alias == p.pkg.Name() {
191192
// If the alias matches the package name, use standard import
192193
imp = append(imp, fmt.Sprintf("%q", path))
193194
} else {
194195
// If we have a custom alias, use aliased import
195-
imp = append(imp, fmt.Sprintf("%s %q", alias, path))
196+
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
196197
}
197198
}
198199
sort.Strings(imp)
@@ -413,7 +414,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
413414
}
414415

415416
func (op uint256Op) genDecode(ctx *genContext) (string, string) {
416-
ctx.addImport("github.com/holiman/uint256")
417+
ctx.addImportPath("github.com/holiman/uint256")
417418

418419
var b bytes.Buffer
419420
resultV := ctx.temp()
@@ -786,7 +787,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
786787
// generateDecoder generates the DecodeRLP method on 'typ'.
787788
func generateDecoder(ctx *genContext, typ string, op op) []byte {
788789
ctx.resetTemp()
789-
ctx.addImport(pathOfPackageRLP)
790+
ctx.addImportPath(pathOfPackageRLP)
790791

791792
result, code := op.genDecode(ctx)
792793
var b bytes.Buffer
@@ -801,8 +802,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
801802
// generateEncoder generates the EncodeRLP method on 'typ'.
802803
func generateEncoder(ctx *genContext, typ string, op op) []byte {
803804
ctx.resetTemp()
804-
ctx.addImport("io")
805-
ctx.addImport(pathOfPackageRLP)
805+
ctx.addImportPath("io")
806+
ctx.addImportPath(pathOfPackageRLP)
806807

807808
var b bytes.Buffer
808809
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
@@ -837,7 +838,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
837838
var b bytes.Buffer
838839
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
839840
for _, imp := range ctx.importsList() {
840-
fmt.Fprintf(&b, "import %q\n", imp)
841+
fmt.Fprintf(&b, "import %s\n", imp)
841842
}
842843
if encoder {
843844
fmt.Fprintln(&b)

0 commit comments

Comments
 (0)