@@ -2,6 +2,7 @@ package paralleltest
22
33import  (
44	"go/ast" 
5+ 	"go/types" 
56	"strings" 
67
78	"golang.org/x/tools/go/analysis" 
@@ -34,9 +35,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
3435		funcDecl  :=  node .(* ast.FuncDecl )
3536		var  funcHasParallelMethod ,
3637			rangeStatementOverTestCasesExists ,
37- 			rangeStatementHasParallelMethod ,
38- 			testLoopVariableReinitialised  bool 
39- 		var  testRunLoopIdentifier  string 
38+ 			rangeStatementHasParallelMethod  bool 
39+ 		var  loopVariableUsedInRun  * string 
4040		var  numberOfTestRun  int 
4141		var  positionOfTestRunNode  []ast.Node 
4242		var  rangeNode  ast.Node 
@@ -81,6 +81,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
8181			case  * ast.RangeStmt :
8282				rangeNode  =  v 
8383
84+ 				var  loopVars  []types.Object 
85+ 				for  _ , expr  :=  range  []ast.Expr {v .Key , v .Value } {
86+ 					if  id , ok  :=  expr .(* ast.Ident ); ok  {
87+ 						loopVars  =  append (loopVars , pass .TypesInfo .ObjectOf (id ))
88+ 					}
89+ 				}
90+ 
8491				ast .Inspect (v , func (n  ast.Node ) bool  {
8592					// nolint: gocritic 
8693					switch  r  :=  n .(type ) {
@@ -90,26 +97,20 @@ func run(pass *analysis.Pass) (interface{}, error) {
9097							innerTestVar  :=  getRunCallbackParameterName (r .X )
9198
9299							rangeStatementOverTestCasesExists  =  true 
93- 							testRunLoopIdentifier  =  methodRunFirstArgumentObjectName (r .X )
94100
95101							if  ! rangeStatementHasParallelMethod  {
96102								rangeStatementHasParallelMethod  =  methodParallelIsCalledInMethodRun (r .X , innerTestVar )
97103							}
104+ 
105+ 							if  loopVariableUsedInRun  ==  nil  {
106+ 								if  run , ok  :=  r .X .(* ast.CallExpr ); ok  {
107+ 									loopVariableUsedInRun  =  isLoopVarReferencedInRun (run , loopVars , pass .TypesInfo )
108+ 								}
109+ 							}
98110						}
99111					}
100112					return  true 
101113				})
102- 
103- 				// Check for the range loop value identifier re assignment 
104- 				// More info here https://gist.github.com/kunwardeep/80c2e9f3d3256c894898bae82d9f75d0 
105- 				if  rangeStatementOverTestCasesExists  {
106- 					var  rangeValueIdentifier  string 
107- 					if  i , ok  :=  v .Value .(* ast.Ident ); ok  {
108- 						rangeValueIdentifier  =  i .Name 
109- 					}
110- 
111- 					testLoopVariableReinitialised  =  testCaseLoopVariableReinitialised (v .Body .List , rangeValueIdentifier , testRunLoopIdentifier )
112- 				}
113114			}
114115		}
115116
@@ -120,12 +121,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
120121		if  rangeStatementOverTestCasesExists  &&  rangeNode  !=  nil  {
121122			if  ! rangeStatementHasParallelMethod  {
122123				pass .Reportf (rangeNode .Pos (), "Range statement for test %s missing the call to method parallel in test Run\n " , funcDecl .Name .Name )
123- 			} else  {
124- 				if  testRunLoopIdentifier  ==  ""  {
125- 					pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not use range value in test Run\n " , funcDecl .Name .Name )
126- 				} else  if  ! testLoopVariableReinitialised  {
127- 					pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not reinitialise the variable %s\n " , funcDecl .Name .Name , testRunLoopIdentifier )
128- 				}
124+ 			} else  if  loopVariableUsedInRun  !=  nil  {
125+ 				pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not reinitialise the variable %s\n " , funcDecl .Name .Name , * loopVariableUsedInRun )
129126			}
130127		}
131128
@@ -140,38 +137,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
140137	return  nil , nil 
141138}
142139
143- func  testCaseLoopVariableReinitialised (statements  []ast.Stmt , rangeValueIdentifier  string , testRunLoopIdentifier  string ) bool  {
144- 	if  len (statements ) >  1  {
145- 		for  _ , s  :=  range  statements  {
146- 			leftIdentifier , rightIdentifier  :=  getLeftAndRightIdentifier (s )
147- 			if  leftIdentifier  ==  testRunLoopIdentifier  &&  rightIdentifier  ==  rangeValueIdentifier  {
148- 				return  true 
149- 			}
150- 		}
151- 	}
152- 	return  false 
153- }
154- 
155- // Return the left hand side and the right hand side identifiers name 
156- func  getLeftAndRightIdentifier (s  ast.Stmt ) (string , string ) {
157- 	var  leftIdentifier , rightIdentifier  string 
158- 	// nolint: gocritic 
159- 	switch  v  :=  s .(type ) {
160- 	case  * ast.AssignStmt :
161- 		if  len (v .Rhs ) ==  1  {
162- 			if  i , ok  :=  v .Rhs [0 ].(* ast.Ident ); ok  {
163- 				rightIdentifier  =  i .Name 
164- 			}
165- 		}
166- 		if  len (v .Lhs ) ==  1  {
167- 			if  i , ok  :=  v .Lhs [0 ].(* ast.Ident ); ok  {
168- 				leftIdentifier  =  i .Name 
169- 			}
170- 		}
171- 	}
172- 	return  leftIdentifier , rightIdentifier 
173- }
174- 
175140func  methodParallelIsCalledInMethodRun (node  ast.Node , testVar  string ) bool  {
176141	var  methodParallelCalled  bool 
177142	// nolint: gocritic 
@@ -247,22 +212,6 @@ func getRunCallbackParameterName(node ast.Node) string {
247212	return  "" 
248213}
249214
250- // Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T) 
251- func  methodRunFirstArgumentObjectName (node  ast.Node ) string  {
252- 	// nolint: gocritic 
253- 	switch  n  :=  node .(type ) {
254- 	case  * ast.CallExpr :
255- 		for  _ , arg  :=  range  n .Args  {
256- 			if  s , ok  :=  arg .(* ast.SelectorExpr ); ok  {
257- 				if  i , ok  :=  s .X .(* ast.Ident ); ok  {
258- 					return  i .Name 
259- 				}
260- 			}
261- 		}
262- 	}
263- 	return  "" 
264- }
265- 
266215// Checks if the function has the param type *testing.T; if it does, then the 
267216// parameter name is returned, too. 
268217func  isTestFunction (funcDecl  * ast.FuncDecl ) (bool , string ) {
@@ -291,3 +240,24 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
291240
292241	return  false , "" 
293242}
243+ 
244+ func  isLoopVarReferencedInRun (call  * ast.CallExpr , vars  []types.Object , typeInfo  * types.Info ) (found  * string ) {
245+ 	if  len (call .Args ) !=  2  {
246+ 		return 
247+ 	}
248+ 
249+ 	ast .Inspect (call .Args [1 ], func (n  ast.Node ) bool  {
250+ 		ident , ok  :=  n .(* ast.Ident )
251+ 		if  ! ok  {
252+ 			return  true 
253+ 		}
254+ 		for  _ , o  :=  range  vars  {
255+ 			if  typeInfo .ObjectOf (ident ) ==  o  {
256+ 				found  =  & ident .Name 
257+ 			}
258+ 		}
259+ 		return  true 
260+ 	})
261+ 
262+ 	return 
263+ }
0 commit comments