diff --git a/deepmd/tf/lmp.py b/deepmd/tf/lmp.py index b2e47308ed..e378beecfe 100644 --- a/deepmd/tf/lmp.py +++ b/deepmd/tf/lmp.py @@ -6,6 +6,9 @@ from importlib import ( import_module, ) +from importlib.util import ( + find_spec, +) from pathlib import ( Path, ) @@ -77,6 +80,11 @@ def get_library_path(module: str, filename: str) -> List[str]: tf_dir = tf.sysconfig.get_lib() op_dir = str(SHARED_LIB_DIR) +pt_spec = find_spec("torch") +if pt_spec is not None: + pt_dir = pt_spec.submodule_search_locations[0] +else: + pt_dir = None cuda_library_paths = [] @@ -106,6 +114,7 @@ def get_library_path(module: str, filename: str) -> List[str]: [ os.environ.get(lib_env), tf_dir, + pt_dir, os.path.join(tf_dir, "python"), op_dir, ]