Skip to content

Commit

Permalink
[TEST][TEDD] improve TEDD tests to also run on CPU Docker image. (#6643)
Browse files Browse the repository at this point in the history
* Amend regular expressions to match with what is being reported
   by CPU Docker image Graphviz
 * Fix typo on dependency checking function
 * Organise imports
  • Loading branch information
leandron authored Oct 9, 2020
1 parent 0922d17 commit f73a1f6
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/python/contrib/test_tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import te
import numpy as np
import re

from tvm import te
from tvm import topi


Expand All @@ -25,7 +25,7 @@ def findany(pattern, str):
assert len(matches) > 0, "Pattern not found.\nPattern: " + pattern + "\nString: " + str


def checkdepdency():
def checkdependency():
import pkg_resources

return not {"graphviz", "ipython"} - {pkg.key for pkg in pkg_resources.working_set}
Expand Down Expand Up @@ -55,7 +55,7 @@ def verify():
findany(r"Tensor_3_0 -> Stage_4:I_1", str)
findany(r"Stage_4:O_0 -> Tensor_4_0", str)

if checkdepdency():
if checkdependency():
verify()


Expand All @@ -79,16 +79,16 @@ def verify():
findany(r"subgraph cluster_Stage_0", str)
findany(r"subgraph cluster_Stage_1", str)
# Check itervars and their types
findany(r"i\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str)
findany(r"k\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str)
findany(r"\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str)
findany(r"\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str)
# Check the split node
findany(r"Split_Relation_1_0 +.+\>Split", str)
# Check all edges to/from the split node
findany(r"IterVar_1_1:itervar -> Split_Relation_1_0:Input", str)
findany(r"Split_Relation_1_0:Outer -> IterVar_1_2:itervar", str)
findany(r"Split_Relation_1_0:Inner -> IterVar_1_3:itervar", str)

if checkdepdency():
if checkdependency():
verify()


Expand Down Expand Up @@ -129,18 +129,18 @@ def verify():
# and compute
findany(
r"Stage_1.*A\.shared<br/>Scope: shared.+>0.+>"
r"ax0\(kDataPar\).+>1.+ax1\(kDataPar\).+>2.+>ax2\(kDataPar\).+>"
r"\[A\(ax0, ax1, ax2\)\]",
r"ax0.*\(kDataPar\).+>1.+ax1.*\(kDataPar\).+>2.+>ax2.*\(kDataPar\).+>"
r"\[A[\[\(]ax0, ax1, ax2[\)\]]\]",
str,
)
# Check itervars of types different from KDataPar
findany(r"bk\(kVectorized\)", str)
findany(r"r.outer\(kCommReduce\)", str)
findany(r"bk.*\(kVectorized\)", str)
findany(r"r.outer.*\(kCommReduce\)", str)
findany(r"label=ROOT", str)
# Check the compute_at edge
findany(r"Stage_1.*\[color\=\"\#000000\"\]", str)

if checkdepdency():
if checkdependency():
verify()


Expand Down

0 comments on commit f73a1f6

Please sign in to comment.