diff --git a/extension_cpp/__init__.py b/extension_cpp/__init__.py index 769c697..bbcf859 100644 --- a/extension_cpp/__init__.py +++ b/extension_cpp/__init__.py @@ -1,2 +1,6 @@ import torch +inside_virtual_environment = sys.prefix != sys.base_prefix +if inside_virtual_environment and os.name == 'nt': + dll_dir = os.path.join(sys.prefix, 'Lib\\site-packages\\torch\\lib') + handle = os.add_dll_directory(dll_dir) from . import _C, ops