Skip to content

Commit

Permalink
Add debug mode to tempdir() (apache#5581)
Browse files Browse the repository at this point in the history
  • Loading branch information
areusch authored and Trevor Morris committed Jun 18, 2020
1 parent 07e6efd commit 7851ad5
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 4 deletions.
63 changes: 59 additions & 4 deletions python/tvm/contrib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,53 @@
# under the License.
"""Common system utilities"""
import atexit
import contextlib
import datetime
import os
import tempfile
import threading
import shutil
try:
import fcntl
except ImportError:
fcntl = None


class DirectoryCreatedPastAtExit(Exception):
"""Raised when a TempDirectory is created after the atexit hook runs."""

class TempDirectory(object):
"""Helper object to manage temp directory during testing.
Automatically removes the directory when it went out of scope.
"""

# When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory
# tree.
_KEEP_FOR_DEBUG = False

# In debug mode, each tempdir is named after the sequence
_NUM_TEMPDIR_CREATED = 0
_NUM_TEMPDIR_CREATED_LOCK = threading.Lock()
@classmethod
def _increment_num_tempdir_created(cls):
with cls._NUM_TEMPDIR_CREATED_LOCK:
to_return = cls._NUM_TEMPDIR_CREATED
cls._NUM_TEMPDIR_CREATED += 1

return to_return

_DEBUG_PARENT_DIR = None
@classmethod
def _get_debug_parent_dir(cls):
if cls._DEBUG_PARENT_DIR is None:
all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs'
if not os.path.isdir(all_parents):
os.makedirs(all_parents)
cls._DEBUG_PARENT_DIR = tempfile.mkdtemp(
prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents)
return cls._DEBUG_PARENT_DIR

TEMPDIRS = set()
@classmethod
def remove_tempdirs(cls):
Expand All @@ -42,20 +75,42 @@ def remove_tempdirs(cls):

cls.TEMPDIRS = None

@classmethod
@contextlib.contextmanager
def set_keep_for_debug(cls, set_to=True):
"""Keep temporary directories past program exit for debugging."""
old_keep_for_debug = cls._KEEP_FOR_DEBUG
try:
cls._KEEP_FOR_DEBUG = set_to
yield
finally:
cls._KEEP_FOR_DEBUG = old_keep_for_debug

def __init__(self, custom_path=None):
if self.TEMPDIRS is None:
raise DirectoryCreatedPastAtExit()

self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG
if custom_path:
os.mkdir(custom_path)
self.temp_dir = custom_path
else:
self.temp_dir = tempfile.mkdtemp()
if self._created_with_keep_for_debug:
parent_dir = self._get_debug_parent_dir()
self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}'
os.mkdir(self.temp_dir)
else:
self.temp_dir = tempfile.mkdtemp()

self.TEMPDIRS.add(self.temp_dir)
if not self._created_with_keep_for_debug:
self.TEMPDIRS.add(self.temp_dir)

def remove(self):
"""Remote the tmp dir"""
if self.temp_dir:
shutil.rmtree(self.temp_dir, ignore_errors=True)
self.TEMPDIRS.remove(self.temp_dir)
if not self._created_with_keep_for_debug:
shutil.rmtree(self.temp_dir, ignore_errors=True)
self.TEMPDIRS.remove(self.temp_dir)
self.temp_dir = None

def __del__(self):
Expand Down
86 changes: 86 additions & 0 deletions tests/python/contrib/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for functions in tvm/python/tvm/contrib/util.py."""

import datetime
import os
import shutil
from tvm.contrib import util


def validate_debug_dir_path(temp_dir, expected_basename):
dirname, basename = os.path.split(temp_dir.temp_dir)
assert basename == expected_basename, 'unexpected basename: %s' % (basename,)

parent_dir = os.path.basename(dirname)
create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S')
assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60)



def test_tempdir():
assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"

temp_dir = util.tempdir()
assert os.path.exists(temp_dir.temp_dir)

old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG
try:
for temp_dir_number in range(0, 3):
with util.TempDirectory.set_keep_for_debug():
debug_temp_dir = util.tempdir()
try:
validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number))
finally:
shutil.rmtree(debug_temp_dir.temp_dir)

with util.TempDirectory.set_keep_for_debug():
# Create 2 temp_dir within the same session.
debug_temp_dir = util.tempdir()
try:
validate_debug_dir_path(debug_temp_dir, '00003')
finally:
shutil.rmtree(debug_temp_dir.temp_dir)

debug_temp_dir = util.tempdir()
try:
validate_debug_dir_path(debug_temp_dir, '00004')
finally:
shutil.rmtree(debug_temp_dir.temp_dir)

with util.TempDirectory.set_keep_for_debug(False):
debug_temp_dir = util.tempdir() # This one should get deleted.

# Simulate atexit hook
util.TempDirectory.remove_tempdirs()

# Calling twice should be a no-op.
util.TempDirectory.remove_tempdirs()

# Creating a new TempDirectory should fail now
try:
util.tempdir()
assert False, 'creation should fail'
except util.DirectoryCreatedPastAtExit:
pass

finally:
util.TempDirectory.DEBUG_MODE = old_debug_mode


if __name__ == '__main__':
test_tempdir()

0 comments on commit 7851ad5

Please sign in to comment.