@@ -419,6 +419,8 @@ func (v *astModsVisitor) Visit(node ast.Node) ast.Visitor {
419
419
v .modifyFuncInit (x )
420
420
} else if x .Name .Name == "RegisterRouters" {
421
421
v .modifyFuncRegisterRouters (x )
422
+ } else if x .Name .Name == "Release" {
423
+ v .modifyFuncRelease (x )
422
424
}
423
425
}
424
426
return v
@@ -630,3 +632,56 @@ func (v *astModsVisitor) modifyFuncRegisterRouters(x *ast.FuncDecl) {
630
632
}
631
633
}
632
634
}
635
+
636
+ func (v * astModsVisitor ) modifyFuncRelease (x * ast.FuncDecl ) {
637
+ findIndex := - 1
638
+ list := x .Body .List
639
+ for i , stmt := range list {
640
+ if s , ok := stmt .(* ast.IfStmt ); ok {
641
+ var sb strings.Builder
642
+ printer .Fprint (& sb , v .fset , s .Init )
643
+ if strings .Contains (sb .String (), fmt .Sprintf ("%s.Release" , v .args .ModuleName )) {
644
+ findIndex = i
645
+ break
646
+ }
647
+ }
648
+ }
649
+ if v .args .Flag & AstFlagGen != 0 {
650
+ if findIndex == - 1 {
651
+ e , err := parser .ParseExpr (fmt .Sprintf ("a.%s.Release(ctx)" , v .args .ModuleName ))
652
+ if err == nil {
653
+ list = append (list [:len (list )- 1 ], append ([]ast.Stmt {& ast.IfStmt {
654
+ Init : & ast.AssignStmt {
655
+ Lhs : []ast.Expr {
656
+ ast .NewIdent ("err" ),
657
+ },
658
+ Tok : token .DEFINE ,
659
+ Rhs : []ast.Expr {
660
+ e ,
661
+ },
662
+ },
663
+ Cond : & ast.BinaryExpr {
664
+ X : ast .NewIdent ("err" ),
665
+ Op : token .NEQ ,
666
+ Y : ast .NewIdent ("nil" ),
667
+ },
668
+ Body : & ast.BlockStmt {
669
+ List : []ast.Stmt {
670
+ & ast.ReturnStmt {
671
+ Results : []ast.Expr {
672
+ ast .NewIdent ("err" ),
673
+ },
674
+ },
675
+ },
676
+ },
677
+ }}, list [len (list )- 1 ])... )
678
+ x .Body .List = list
679
+ }
680
+ }
681
+ } else if v .args .Flag & AstFlagRem != 0 {
682
+ if findIndex != - 1 {
683
+ list = append (list [:findIndex ], list [findIndex + 1 :]... )
684
+ x .Body .List = list
685
+ }
686
+ }
687
+ }
0 commit comments