@@ -22,7 +22,6 @@ import (
2222 "go/format"
2323 "go/types"
2424 "sort"
25- "strings"
2625
2726 "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
2827 "golang.org/x/tools/go/packages"
@@ -98,14 +97,19 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
9897// file and assigns unique names of temporary variables.
9998type genContext struct {
10099 inPackage * types.Package
101- imports map [string ]string // Changed from map[string]struct{} to map[string]string to store package aliases
100+ imports map [string ]genImportPackage
102101 tempCounter int
103102}
104103
104+ type genImportPackage struct {
105+ alias string
106+ pkg * types.Package
107+ }
108+
105109func newGenContext (inPackage * types.Package ) * genContext {
106110 return & genContext {
107111 inPackage : inPackage ,
108- imports : make (map [string ]string ),
112+ imports : make (map [string ]genImportPackage ),
109113 tempCounter : 0 ,
110114 }
111115}
@@ -120,42 +124,39 @@ func (ctx *genContext) resetTemp() {
120124 ctx .tempCounter = 0
121125}
122126
123- func (ctx * genContext ) addImport (path string ) string {
124- if path == ctx .inPackage .Path () {
125- return "" // avoid importing the package that we're generating in
126- }
127-
128- // Check if we already have an alias for this package
129- if alias , exists := ctx .imports [path ]; exists {
130- return alias
131- }
132-
133- // Get the package name and check for conflicts
127+ func (ctx * genContext ) addImportPath (path string ) {
134128 pkg , err := ctx .loadPackage (path )
135129 if err != nil {
136- // If we can't load the package, use the last component of the path
137- parts := strings .Split (path , "/" )
138- pkg = types .NewPackage (path , parts [len (parts )- 1 ])
130+ panic (fmt .Sprintf ("can't load package %q: %v" , path , err ))
139131 }
132+ ctx .addImport (pkg )
133+ }
140134
135+ func (ctx * genContext ) addImport (pkg * types.Package ) string {
136+ if pkg .Path () == ctx .inPackage .Path () {
137+ return "" // avoid importing the package that we're generating in
138+ }
139+ if p , exists := ctx .imports [pkg .Path ()]; exists {
140+ return p .alias
141+ }
141142 baseName := pkg .Name ()
142143 alias := baseName
143144 counter := 1
144-
145+
145146 // If the base name conflicts with any existing import, add a numeric suffix
146147 for ctx .hasAlias (alias ) {
147148 alias = fmt .Sprintf ("%s%d" , baseName , counter )
148149 counter ++
149150 }
150-
151- ctx .imports [path ] = alias
151+
152+ ctx .imports [pkg . Path () ] = genImportPackage { alias , pkg }
152153 return alias
153154}
154155
155156// hasAlias checks if an alias is already in use
156157func (ctx * genContext ) hasAlias (alias string ) bool {
157- for _ , existingAlias := range ctx .imports {
158- if existingAlias == alias {
158+ for _ , p := range ctx .imports {
159+ if p . alias == alias {
159160 return true
160161 }
161162 }
@@ -180,19 +181,19 @@ func (ctx *genContext) qualify(pkg *types.Package) string {
180181 if pkg .Path () == ctx .inPackage .Path () {
181182 return ""
182183 }
183- return ctx .addImport (pkg . Path () )
184+ return ctx .addImport (pkg )
184185}
185186
186187// importsList returns all packages that need to be imported
187188func (ctx * genContext ) importsList () []string {
188189 imp := make ([]string , 0 , len (ctx .imports ))
189- for path , alias := range ctx .imports {
190- if alias == pkg .Name () {
190+ for path , p := range ctx .imports {
191+ if p . alias == p . pkg .Name () {
191192 // If the alias matches the package name, use standard import
192193 imp = append (imp , fmt .Sprintf ("%q" , path ))
193194 } else {
194195 // If we have a custom alias, use aliased import
195- imp = append (imp , fmt .Sprintf ("%s %q" , alias , path ))
196+ imp = append (imp , fmt .Sprintf ("%s %q" , p . alias , path ))
196197 }
197198 }
198199 sort .Strings (imp )
@@ -413,7 +414,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
413414}
414415
415416func (op uint256Op ) genDecode (ctx * genContext ) (string , string ) {
416- ctx .addImport ("github.com/holiman/uint256" )
417+ ctx .addImportPath ("github.com/holiman/uint256" )
417418
418419 var b bytes.Buffer
419420 resultV := ctx .temp ()
@@ -786,7 +787,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
786787// generateDecoder generates the DecodeRLP method on 'typ'.
787788func generateDecoder (ctx * genContext , typ string , op op ) []byte {
788789 ctx .resetTemp ()
789- ctx .addImport (pathOfPackageRLP )
790+ ctx .addImportPath (pathOfPackageRLP )
790791
791792 result , code := op .genDecode (ctx )
792793 var b bytes.Buffer
@@ -801,8 +802,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
801802// generateEncoder generates the EncodeRLP method on 'typ'.
802803func generateEncoder (ctx * genContext , typ string , op op ) []byte {
803804 ctx .resetTemp ()
804- ctx .addImport ("io" )
805- ctx .addImport (pathOfPackageRLP )
805+ ctx .addImportPath ("io" )
806+ ctx .addImportPath (pathOfPackageRLP )
806807
807808 var b bytes.Buffer
808809 fmt .Fprintf (& b , "func (obj *%s) EncodeRLP(_w io.Writer) error {\n " , typ )
@@ -837,7 +838,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
837838 var b bytes.Buffer
838839 fmt .Fprintf (& b , "package %s\n \n " , pkg .Name ())
839840 for _ , imp := range ctx .importsList () {
840- fmt .Fprintf (& b , "import %q \n " , imp )
841+ fmt .Fprintf (& b , "import %s \n " , imp )
841842 }
842843 if encoder {
843844 fmt .Fprintln (& b )
0 commit comments