diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py index 8c97214ea4bb..d6bb7622d640 100644 --- a/topi/python/topi/x86/injective.py +++ b/topi/python/topi/x86/injective.py @@ -45,6 +45,12 @@ def schedule_injective_from_existing(sch, out): sch[out].parallel(fused) elif len(sch[out].op.axis) >= 1: sch[out].parallel(sch[out].op.axis[0]) + + # Vectorize the inner most for loop. Tiling first to get a const extent + if len(sch[out].op.axis) >= 1: + l = sch[out].op.axis[-1] + _, li = sch[out].split(l, factor=16) + sch[out].vectorize(li) return sch @generic.schedule_injective.register(["cpu"])