diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index defb0612144c..f1422bf28213 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -226,7 +226,8 @@ def __init__(self, axes, policy, **kwargs): def _generate_space(self, now, tmp_stack, enforce_no_tail=False): """Generate space by DFS""" if now == self.num_output - 1: - if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0: + prod = np.prod(tmp_stack, dtype=np.int64) + if self.product % prod == 0 or (not enforce_no_tail and prod < self.product): self.entities.append(SplitEntity([-1] + tmp_stack[::-1])) else: for factor in self.factors: diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py index 1da3fb0182ba..85d572412f9e 100644 --- a/tests/python/unittest/test_autotvm_space.py +++ b/tests/python/unittest/test_autotvm_space.py @@ -42,6 +42,26 @@ def test_split(): assert len(cfg) == 64 assert len(cfg.space_map['tile_y']) == 8 + # test policy + cfg = ConfigSpace() + cfg.define_split('tile_x', cfg.axis(256), policy='factors', num_outputs=3) + assert len(cfg.space_map['tile_x']) == 45 + + cfg.define_split('tile_y', cfg.axis(256), policy='power2', num_outputs=3) + assert len(cfg.space_map['tile_y']) == 45 + + cfg.define_split('tile_z', cfg.axis(256), policy='verbose', num_outputs=3) + assert len(cfg.space_map['tile_z']) == 45 + + cfg.define_split('tile_a', cfg.axis(224), policy='factors', num_outputs=3) + assert len(cfg.space_map['tile_a']) == 63 + + cfg.define_split('tile_b', cfg.axis(224), policy='power2', num_outputs=3) + assert len(cfg.space_map['tile_b']) == 36 + + cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3) + assert len(cfg.space_map['tile_c']) == 84 + # test fallback cfg = FallbackConfigEntity() cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)