diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index 6d2ffbac0673..10f3f7c6df04 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Memoize result of function via pickle, used for cache testcases.""" + # pylint: disable=broad-except,superfluous-parens +import atexit import os +import pathlib import sys -import atexit + from decorator import decorate from .._ffi.base import string_types @@ -28,6 +31,17 @@ import pickle +def _get_global_cache_dir() -> pathlib.Path: + if "XDG_CACHE_HOME" in os.environ: + cache_home = pathlib.Path(os.environ.get("XDG_CACHE_HOME")) + else: + cache_home = pathlib.Path.home().joinpath(".cache") + return cache_home.joinpath("tvm", f"pkl_memoize_py{sys.version_info[0]}") + + +GLOBAL_CACHE_DIR = _get_global_cache_dir() + + class Cache(object): """A cache object for result cache. @@ -42,28 +56,34 @@ class Cache(object): cache_by_key = {} def __init__(self, key, save_at_exit): - cache_dir = f".pkl_memoize_py{sys.version_info[0]}" - try: - os.mkdir(cache_dir) - except FileExistsError: - pass - else: - self.cache = {} - self.path = os.path.join(cache_dir, key) - if os.path.exists(self.path): - try: - self.cache = pickle.load(open(self.path, "rb")) - except Exception: - self.cache = {} - else: - self.cache = {} + self._cache = None + + self.path = GLOBAL_CACHE_DIR.joinpath(key) self.dirty = False self.save_at_exit = save_at_exit + @property + def cache(self): + if self._cache is not None: + return self._cache + + if self.path.exists(): + with self.path.open("rb") as cache_file: + try: + cache = pickle.load(cache_file) + except pickle.UnpicklingError: + cache = {} + else: + cache = {} + + self._cache = cache + return self._cache + def save(self): if self.dirty: - print(f"Save memoize result to {self.path}") - with open(self.path, "wb") as out_file: + self.path.parent.mkdir(parents=True, exist_ok=True) + + with self.path.open("wb") as out_file: pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL) diff --git a/tests/python/contrib/pickle_memoize_script.py b/tests/python/contrib/pickle_memoize_script.py new file mode 100755 index 000000000000..f0d73e391066 --- /dev/null +++ b/tests/python/contrib/pickle_memoize_script.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import tvm + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_save_data", save_at_exit=True) +def get_data_saved(): + return 42 + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_transient_data", save_at_exit=False) +def get_data_transient(): + return 42 + + +def main(): + assert len(sys.argv) == 3, "Expect arguments SCRIPT NUM_SAVED NUM_TRANSIENT" + + num_iter_saved = int(sys.argv[1]) + num_iter_transient = int(sys.argv[2]) + + for _ in range(num_iter_saved): + get_data_saved() + for _ in range(num_iter_transient): + get_data_transient() + + +if __name__ == "__main__": + main() diff --git a/tests/python/contrib/test_memoize.py b/tests/python/contrib/test_memoize.py new file mode 100644 index 000000000000..6881940e5062 --- /dev/null +++ b/tests/python/contrib/test_memoize.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for tvm.contrib.pickle_memoize""" + +import os +import pathlib +import tempfile +import subprocess +import sys + +import tvm.testing + +TEST_SCRIPT_FILE = pathlib.Path(__file__).with_name("pickle_memoize_script.py").resolve() + + +def test_cache_dir_not_in_current_working_dir(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + subprocess.check_call([TEST_SCRIPT_FILE, "1", "1"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + +def test_current_directory_is_not_required_to_be_writable(): + """TVM may be imported without directory permissions + + This is a regression test. In previous implementations, the + `tvm.contrib.pickle_memoize.memoize` function would write to the + current directory when importing TVM. Import of a Python module + should not write to any directory. + + """ + + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # User may read/cd into the temp dir, nobody may write to temp + # dir. + temp_dir.chmod(0o500) + subprocess.check_call([sys.executable, "-c", "import tvm"], cwd=temp_dir) + + +def test_cache_dir_defaults_to_home_config_cache(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + subprocess.check_call([TEST_SCRIPT_FILE, "1", "0"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = pathlib.Path.home().joinpath(".cache", "tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) >= 1 + + +def test_cache_dir_respects_xdg_cache_home(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "1", "0"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + new_files = list(temp_working_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) == 1 + + +def test_cache_dir_only_created_when_used(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "0", "1"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert not cache_dir.exists() + + +if __name__ == "__main__": + tvm.testing.main()