-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bugfix] Allow import of TVM when current directory is read-only
Prior to this commit, TVM could only be imported if the current directory had write privileges. This was due to the use of `tvm.contrib.pickle_memoize` to cache the winograd transformation matrices. This commit makes multiple related fixes, to ensure that (1) TVM can be imported regardless of directory permissions, (2) the working directory is not left in a cluttered state, and (3) cache files are generated in an expected location to be reused later. * The cache directory is only generated when required, just prior to saving a cache. * The cache directory defaults to `$HOME/.cache/tvm/pkl_memoize`, rather than `.pkl_memorize_py3` in the working directory. * The cache directory respects `XDG_CACHE_HOME`, using `$XDG_CACHE_HOME/tvm/pkl_memoize` if set.
- Loading branch information
1 parent
0fc047c
commit 110314b
Showing
3 changed files
with
212 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import sys | ||
|
||
import tvm | ||
|
||
|
||
@tvm.contrib.pickle_memoize.memoize("test_memoize_save_data", save_at_exit=True) | ||
def get_data_saved(): | ||
return 42 | ||
|
||
|
||
@tvm.contrib.pickle_memoize.memoize("test_memoize_transient_data", save_at_exit=False) | ||
def get_data_transient(): | ||
return 42 | ||
|
||
|
||
def main(): | ||
assert len(sys.argv) == 3, "Expect arguments SCRIPT NUM_SAVED NUM_TRANSIENT" | ||
|
||
num_iter_saved = int(sys.argv[1]) | ||
num_iter_transient = int(sys.argv[2]) | ||
|
||
for _ in range(num_iter_saved): | ||
get_data_saved() | ||
for _ in range(num_iter_transient): | ||
get_data_transient() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Tests for tvm.contrib.pickle_memoize""" | ||
|
||
import os | ||
import pathlib | ||
import tempfile | ||
import subprocess | ||
import sys | ||
|
||
import tvm.testing | ||
|
||
TEST_SCRIPT_FILE = pathlib.Path(__file__).with_name("pickle_memoize_script.py").resolve() | ||
|
||
|
||
def test_cache_dir_not_in_current_working_dir(): | ||
with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: | ||
temp_dir = pathlib.Path(temp_dir) | ||
subprocess.check_call([TEST_SCRIPT_FILE, "1", "1"], cwd=temp_dir) | ||
|
||
new_files = list(temp_dir.iterdir()) | ||
assert ( | ||
not new_files | ||
), "Use of tvm.contrib.pickle_memorize may not write to current directory." | ||
|
||
|
||
def test_current_directory_is_not_required_to_be_writable(): | ||
"""TVM may be imported without directory permissions | ||
This is a regression test. In previous implementations, the | ||
`tvm.contrib.pickle_memoize.memoize` function would write to the | ||
current directory when importing TVM. Import of a Python module | ||
should not write to any directory. | ||
""" | ||
|
||
with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: | ||
temp_dir = pathlib.Path(temp_dir) | ||
|
||
# User may read/cd into the temp dir, nobody may write to temp | ||
# dir. | ||
temp_dir.chmod(0o500) | ||
subprocess.check_call([sys.executable, "-c", "import tvm"], cwd=temp_dir) | ||
|
||
|
||
def test_cache_dir_defaults_to_home_config_cache(): | ||
with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: | ||
temp_dir = pathlib.Path(temp_dir) | ||
|
||
subprocess.check_call([TEST_SCRIPT_FILE, "1", "0"], cwd=temp_dir) | ||
|
||
new_files = list(temp_dir.iterdir()) | ||
assert ( | ||
not new_files | ||
), "Use of tvm.contrib.pickle_memorize may not write to current directory." | ||
|
||
cache_dir = pathlib.Path.home().joinpath(".cache", "tvm", "pkl_memoize_py3") | ||
assert cache_dir.exists() | ||
cache_files = list(cache_dir.iterdir()) | ||
assert len(cache_files) >= 1 | ||
|
||
|
||
def test_cache_dir_respects_xdg_cache_home(): | ||
with tempfile.TemporaryDirectory( | ||
prefix="tvm_" | ||
) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: | ||
temp_cache_dir = pathlib.Path(temp_cache_dir) | ||
temp_working_dir = pathlib.Path(temp_working_dir) | ||
|
||
subprocess.check_call( | ||
[TEST_SCRIPT_FILE, "1", "0"], | ||
cwd=temp_working_dir, | ||
env={ | ||
**os.environ, | ||
"XDG_CACHE_HOME": temp_cache_dir.as_posix(), | ||
}, | ||
) | ||
|
||
new_files = list(temp_working_dir.iterdir()) | ||
assert ( | ||
not new_files | ||
), "Use of tvm.contrib.pickle_memorize may not write to current directory." | ||
|
||
cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") | ||
assert cache_dir.exists() | ||
cache_files = list(cache_dir.iterdir()) | ||
assert len(cache_files) == 1 | ||
|
||
|
||
def test_cache_dir_only_created_when_used(): | ||
with tempfile.TemporaryDirectory( | ||
prefix="tvm_" | ||
) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: | ||
temp_cache_dir = pathlib.Path(temp_cache_dir) | ||
temp_working_dir = pathlib.Path(temp_working_dir) | ||
|
||
subprocess.check_call( | ||
[TEST_SCRIPT_FILE, "0", "1"], | ||
cwd=temp_working_dir, | ||
env={ | ||
**os.environ, | ||
"XDG_CACHE_HOME": temp_cache_dir.as_posix(), | ||
}, | ||
) | ||
|
||
cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") | ||
assert not cache_dir.exists() | ||
|
||
|
||
if __name__ == "__main__": | ||
tvm.testing.main() |