From a1c85dc5aebce41c94aaa21b13720b31312ae47d Mon Sep 17 00:00:00 2001 From: Wei Pan Date: Fri, 24 Jan 2020 13:13:30 -0800 Subject: [PATCH] [AUTOTVM] Fix a bug in generating the search space - Do not use numpy.prod which ignores integer (64 bits) overflows. This leads to an incorrect number of points in the search space. --- python/tvm/autotvm/task/space.py | 4 +++- tests/python/unittest/test_autotvm_space.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index f1422bf28213..d83a248c4ece 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -226,7 +226,9 @@ def __init__(self, axes, policy, **kwargs): def _generate_space(self, now, tmp_stack, enforce_no_tail=False): """Generate space by DFS""" if now == self.num_output - 1: - prod = np.prod(tmp_stack, dtype=np.int64) + prod = functools.reduce(lambda x, y: x * y, tmp_stack) + if prod > self.product: + return if self.product % prod == 0 or (not enforce_no_tail and prod < self.product): self.entities.append(SplitEntity([-1] + tmp_stack[::-1])) else: diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py index 85d572412f9e..95f3201c5eb4 100644 --- a/tests/python/unittest/test_autotvm_space.py +++ b/tests/python/unittest/test_autotvm_space.py @@ -62,6 +62,21 @@ def test_split(): cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3) assert len(cfg.space_map['tile_c']) == 84 + # Count the number of non-negative integer solutions of a + b + c + d = n + def count4(n): + cnt = 0 + for a in range(0, n + 1): + for b in range(0, n - a + 1): + cnt += n - a - b + 1 + return cnt + + # test overflow + n = 25 + cfg = ConfigSpace() + cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4) + # count4(25) is 3276. + assert len(cfg.space_map['x']) == count4(n) + # test fallback cfg = FallbackConfigEntity() cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)