From caddfd8f1a6550696d15284d94d7dc73875303dc Mon Sep 17 00:00:00 2001 From: Aofei Sheng Date: Tue, 11 Mar 2025 08:49:27 +0800 Subject: [PATCH] refactor(ast): improve `Walk` with latest `go/ast.Walk` - Replaced multiple type-specific `walkXList` helper funcs with a single generic `walkList[N Node]` func. - Added `TypeParams` handling for `FuncType` and `TypeSpec`. - Improved nil-safety with additional checks. Signed-off-by: Aofei Sheng --- ast/walk.go | 94 ++++++++++++++++++++--------------------------------- 1 file changed, 36 insertions(+), 58 deletions(-) diff --git a/ast/walk.go b/ast/walk.go index 38cfc3fcc..b3287fc30 100644 --- a/ast/walk.go +++ b/ast/walk.go @@ -27,29 +27,9 @@ type Visitor interface { Visit(node Node) (w Visitor) } -// Helper functions for common node lists. They may be empty. - -func walkIdentList(v Visitor, list []*Ident) { - for _, x := range list { - Walk(v, x) - } -} - -func walkExprList(v Visitor, list []Expr) { - for _, x := range list { - Walk(v, x) - } -} - -func walkStmtList(v Visitor, list []Stmt) { - for _, x := range list { - Walk(v, x) - } -} - -func walkDeclList(v Visitor, list []Decl) { - for _, x := range list { - Walk(v, x) +func walkList[N Node](v Visitor, list []N) { + for _, node := range list { + Walk(v, node) } } @@ -75,16 +55,16 @@ func Walk(v Visitor, node Node) { // nothing to do case *CommentGroup: - for _, c := range n.List { - Walk(v, c) - } + walkList(v, n.List) case *Field: if n.Doc != nil { Walk(v, n.Doc) } - walkIdentList(v, n.Names) - Walk(v, n.Type) + walkList(v, n.Names) + if n.Type != nil { + Walk(v, n.Type) + } if n.Tag != nil { Walk(v, n.Tag) } @@ -93,9 +73,7 @@ func Walk(v Visitor, node Node) { } case *FieldList: - for _, f := range n.List { - Walk(v, f) - } + walkList(v, n.List) // Expressions case *BadExpr, *Ident, *NumberUnitLit: @@ -133,7 +111,7 @@ func Walk(v Visitor, node Node) { if n.Type != nil { Walk(v, n.Type) } - walkExprList(v, n.Elts) + walkList(v, n.Elts) case *ParenExpr: Walk(v, n.X) @@ -148,7 +126,7 @@ func Walk(v Visitor, node Node) { case *IndexListExpr: Walk(v, n.X) - walkExprList(v, n.Indices) + walkList(v, n.Indices) case *SliceExpr: Walk(v, n.X) @@ -170,7 +148,7 @@ func Walk(v Visitor, node Node) { case *CallExpr: Walk(v, n.Fun) - walkExprList(v, n.Args) + walkList(v, n.Args) case *StarExpr: Walk(v, n.X) @@ -197,6 +175,9 @@ func Walk(v Visitor, node Node) { Walk(v, n.Fields) case *FuncType: + if n.TypeParams != nil { + Walk(v, n.TypeParams) + } if n.Params != nil { Walk(v, n.Params) } @@ -233,16 +214,14 @@ func Walk(v Visitor, node Node) { case *SendStmt: Walk(v, n.Chan) - for _, val := range n.Values { - Walk(v, val) - } + walkList(v, n.Values) case *IncDecStmt: Walk(v, n.X) case *AssignStmt: - walkExprList(v, n.Lhs) - walkExprList(v, n.Rhs) + walkList(v, n.Lhs) + walkList(v, n.Rhs) case *GoStmt: Walk(v, n.Call) @@ -251,7 +230,7 @@ func Walk(v Visitor, node Node) { Walk(v, n.Call) case *ReturnStmt: - walkExprList(v, n.Results) + walkList(v, n.Results) case *BranchStmt: if n.Label != nil { @@ -259,7 +238,7 @@ func Walk(v Visitor, node Node) { } case *BlockStmt: - walkStmtList(v, n.List) + walkList(v, n.List) case *IfStmt: if n.Init != nil { @@ -272,8 +251,8 @@ func Walk(v Visitor, node Node) { } case *CaseClause: - walkExprList(v, n.List) - walkStmtList(v, n.Body) + walkList(v, n.List) + walkList(v, n.Body) case *SwitchStmt: if n.Init != nil { @@ -295,7 +274,7 @@ func Walk(v Visitor, node Node) { if n.Comm != nil { Walk(v, n.Comm) } - walkStmtList(v, n.Body) + walkList(v, n.Body) case *SelectStmt: Walk(v, n.Body) @@ -339,11 +318,11 @@ func Walk(v Visitor, node Node) { if n.Doc != nil { Walk(v, n.Doc) } - walkIdentList(v, n.Names) + walkList(v, n.Names) if n.Type != nil { Walk(v, n.Type) } - walkExprList(v, n.Values) + walkList(v, n.Values) if n.Comment != nil { Walk(v, n.Comment) } @@ -353,6 +332,9 @@ func Walk(v Visitor, node Node) { Walk(v, n.Doc) } Walk(v, n.Name) + if n.TypeParams != nil { + Walk(v, n.TypeParams) + } Walk(v, n.Type) if n.Comment != nil { Walk(v, n.Comment) @@ -365,9 +347,7 @@ func Walk(v Visitor, node Node) { if n.Doc != nil { Walk(v, n.Doc) } - for _, s := range n.Specs { - Walk(v, s) - } + walkList(v, n.Specs) case *FuncDecl: if !n.Shadow { // not a shadow entry @@ -392,7 +372,7 @@ func Walk(v Visitor, node Node) { if !n.NoPkgDecl { Walk(v, n.Name) } - walkDeclList(v, n.Decls) + walkList(v, n.Decls) // don't walk n.Comments - they have been // visited already through the individual // nodes @@ -404,14 +384,14 @@ func Walk(v Visitor, node Node) { // Go+ extended expr and stmt case *SliceLit: - walkExprList(v, n.Elts) + walkList(v, n.Elts) case *LambdaExpr: - walkIdentList(v, n.Lhs) - walkExprList(v, n.Rhs) + walkList(v, n.Lhs) + walkList(v, n.Rhs) case *LambdaExpr2: - walkIdentList(v, n.Lhs) + walkList(v, n.Lhs) Walk(v, n.Body) case *ForPhrase: @@ -433,9 +413,7 @@ func Walk(v Visitor, node Node) { if n.Elt != nil { Walk(v, n.Elt) } - for _, x := range n.Fors { - Walk(v, x) - } + walkList(v, n.Fors) case *ForPhraseStmt: Walk(v, n.ForPhrase) @@ -466,7 +444,7 @@ func Walk(v Visitor, node Node) { Walk(v, n.Recv) } Walk(v, n.Name) - walkExprList(v, n.Funcs) + walkList(v, n.Funcs) case *EnvExpr: Walk(v, n.Name)