diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py index f13670e39895..fd0ed0c8fb2f 100644 --- a/python/tvm/contrib/tf_op/module.py +++ b/python/tvm/contrib/tf_op/module.py @@ -17,6 +17,7 @@ """Module container of TensorFlow TVMDSO op""" import tensorflow as tf from tensorflow.python.framework import load_library +from tensorflow.python import platform class OpModule: @@ -67,7 +68,7 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape): elif output_shape is not None: self.dynamic_output_shape = self._pack_shape_tensor(output_shape) - self.module = load_library.load_op_library('tvm_dso_op.so') + self.module = self._load_platform_specific_library("tvm_dso_op") self.tvm_dso_op = self.module.tvm_dso_op def apply(self, *params): @@ -82,6 +83,16 @@ def apply(self, *params): def __call__(self, *params): return self.apply(*params) + def _load_platform_specific_library(self, lib_name): + system = platform.system() + if system == "Darwin": + lib_file_name = lib_name + ".dylib" + elif system == "Windows": + lib_file_name = lib_name + ".dll" + else: + lib_file_name = lib_name + ".so" + return load_library.load_op_library(lib_file_name) + def _is_static_shape(self, shape): if shape is None or not isinstance(shape, list): return False