Skip to content

Commit

Permalink
Fix split's last factor issue (apache#4044)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and vinx13 committed Oct 1, 2019
1 parent 2f1edb9 commit 2d53762
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def __init__(self, axes, policy, **kwargs):
def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
"""Generate space by DFS"""
if now == self.num_output - 1:
if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0:
prod = np.prod(tmp_stack, dtype=np.int64)
if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
else:
for factor in self.factors:
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_autotvm_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ def test_split():
assert len(cfg) == 64
assert len(cfg.space_map['tile_y']) == 8

# test policy
cfg = ConfigSpace()
cfg.define_split('tile_x', cfg.axis(256), policy='factors', num_outputs=3)
assert len(cfg.space_map['tile_x']) == 45

cfg.define_split('tile_y', cfg.axis(256), policy='power2', num_outputs=3)
assert len(cfg.space_map['tile_y']) == 45

cfg.define_split('tile_z', cfg.axis(256), policy='verbose', num_outputs=3)
assert len(cfg.space_map['tile_z']) == 45

cfg.define_split('tile_a', cfg.axis(224), policy='factors', num_outputs=3)
assert len(cfg.space_map['tile_a']) == 63

cfg.define_split('tile_b', cfg.axis(224), policy='power2', num_outputs=3)
assert len(cfg.space_map['tile_b']) == 36

cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
assert len(cfg.space_map['tile_c']) == 84

# test fallback
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
Expand Down

0 comments on commit 2d53762

Please sign in to comment.