diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py index e980e5520802..8f6dfc7f28ec 100644 --- a/python/tvm/contrib/util.py +++ b/python/tvm/contrib/util.py @@ -16,20 +16,53 @@ # under the License. """Common system utilities""" import atexit +import contextlib +import datetime import os import tempfile +import threading import shutil try: import fcntl except ImportError: fcntl = None + +class DirectoryCreatedPastAtExit(Exception): + """Raised when a TempDirectory is created after the atexit hook runs.""" + class TempDirectory(object): """Helper object to manage temp directory during testing. Automatically removes the directory when it went out of scope. """ + # When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory + # tree. + _KEEP_FOR_DEBUG = False + + # In debug mode, each tempdir is named after the sequence + _NUM_TEMPDIR_CREATED = 0 + _NUM_TEMPDIR_CREATED_LOCK = threading.Lock() + @classmethod + def _increment_num_tempdir_created(cls): + with cls._NUM_TEMPDIR_CREATED_LOCK: + to_return = cls._NUM_TEMPDIR_CREATED + cls._NUM_TEMPDIR_CREATED += 1 + + return to_return + + _DEBUG_PARENT_DIR = None + @classmethod + def _get_debug_parent_dir(cls): + if cls._DEBUG_PARENT_DIR is None: + all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs' + if not os.path.isdir(all_parents): + os.makedirs(all_parents) + cls._DEBUG_PARENT_DIR = tempfile.mkdtemp( + prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents) + return cls._DEBUG_PARENT_DIR + TEMPDIRS = set() @classmethod def remove_tempdirs(cls): @@ -42,20 +75,42 @@ def remove_tempdirs(cls): cls.TEMPDIRS = None + @classmethod + @contextlib.contextmanager + def set_keep_for_debug(cls, set_to=True): + """Keep temporary directories past program exit for debugging.""" + old_keep_for_debug = cls._KEEP_FOR_DEBUG + try: + cls._KEEP_FOR_DEBUG = set_to + yield + finally: + cls._KEEP_FOR_DEBUG = old_keep_for_debug + def __init__(self, custom_path=None): + if self.TEMPDIRS is None: + raise DirectoryCreatedPastAtExit() + + self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG if custom_path: os.mkdir(custom_path) self.temp_dir = custom_path else: - self.temp_dir = tempfile.mkdtemp() + if self._created_with_keep_for_debug: + parent_dir = self._get_debug_parent_dir() + self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}' + os.mkdir(self.temp_dir) + else: + self.temp_dir = tempfile.mkdtemp() - self.TEMPDIRS.add(self.temp_dir) + if not self._created_with_keep_for_debug: + self.TEMPDIRS.add(self.temp_dir) def remove(self): """Remote the tmp dir""" if self.temp_dir: - shutil.rmtree(self.temp_dir, ignore_errors=True) - self.TEMPDIRS.remove(self.temp_dir) + if not self._created_with_keep_for_debug: + shutil.rmtree(self.temp_dir, ignore_errors=True) + self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None def __del__(self): diff --git a/tests/python/contrib/test_util.py b/tests/python/contrib/test_util.py new file mode 100644 index 000000000000..55a2b7616e84 --- /dev/null +++ b/tests/python/contrib/test_util.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for functions in tvm/python/tvm/contrib/util.py.""" + +import datetime +import os +import shutil +from tvm.contrib import util + + +def validate_debug_dir_path(temp_dir, expected_basename): + dirname, basename = os.path.split(temp_dir.temp_dir) + assert basename == expected_basename, 'unexpected basename: %s' % (basename,) + + parent_dir = os.path.basename(dirname) + create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S') + assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60) + + + +def test_tempdir(): + assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True" + + temp_dir = util.tempdir() + assert os.path.exists(temp_dir.temp_dir) + + old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG + try: + for temp_dir_number in range(0, 3): + with util.TempDirectory.set_keep_for_debug(): + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number)) + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + with util.TempDirectory.set_keep_for_debug(): + # Create 2 temp_dir within the same session. + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '00003') + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '00004') + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + with util.TempDirectory.set_keep_for_debug(False): + debug_temp_dir = util.tempdir() # This one should get deleted. + + # Simulate atexit hook + util.TempDirectory.remove_tempdirs() + + # Calling twice should be a no-op. + util.TempDirectory.remove_tempdirs() + + # Creating a new TempDirectory should fail now + try: + util.tempdir() + assert False, 'creation should fail' + except util.DirectoryCreatedPastAtExit: + pass + + finally: + util.TempDirectory.DEBUG_MODE = old_debug_mode + + +if __name__ == '__main__': + test_tempdir()