[RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow #4464
Comments
The implementation of this proposal has been submitted in #4459. Anyone can try to test their TVM operators by re-compiling TVM with |
The motivations of this RFC are extremely similar to those in pytorch-tvm, however the two implementations are very different and it is worth discussing the tradeoffs.
I understand that the current implementation is the shortest path to getting TVM functions working in TensorFlow and that a torch-tvm approach would be a much larger undertaking. However, I don't think it will scale well. The use of prebuilt libraries means there will be a lot of back and forth between regular tvm and tensorflow-tvm during development, and it seems like developers would be better off just importing their tf model to relay and doing everything within tvm. Contrast this with the torch-tvm approach, where all the tvm magic happens transparently, making it very straightforward for pytorch users. We should also consider where the code belongs. I personally prefer having projects like torch-tvm and tf-tvm separate from the main tvm repo if possible, as we are already dealing with frontend bloat. All that said, I think something like tf-tvm is a great idea and something we should work towards. I just want to make sure we take the first step carefully. |
Thanks @jwfromm, and you're definitely right. This is the fastest way to integrate TVM functions into TensorFlow when we cannot convert the whole model to TVM. It may be meaningful for TensorFlow developers who want to try TVM and leverage its sub-graph optimization. In effect, this project is a TensorFlow custom op with the TVM runtime. We originally developed it in the standalone project https://github.com/tobegit3hub/tftvm . Since it depends on both TVM and TensorFlow to compile, it could either be one of the TVM contrib libraries or be maintained as an independent project. |
That makes sense, you're right that having it in contrib clears up a lot of my concerns. Thanks for those clarifications! |
The PR has been merged and we will close this issue. |
@tobegit3hub Hi guys, I built TVM earlier and then built tvmsoop separately (not with USE_TF_TVMSOOP=ON), following https://github.com/tobegit3hub/tftvm/tree/master/examples . I ran:
import tensorflow as tf
mod = tf_op.Module("tvm_addone_dll.so")
with tf.Session() as sess:
and I got this error: My gcc is 6.4.0 and my TensorFlow is tf-1.15.0; I built it from source with bazel and set -D_GLIBCXX_CXX11_ABI=1. |
@652994331 You should not use |
@tobegit3hub i tried before, but unfortunately, i had this error: |
@tobegit3hub and here is the entire cmake log: |
@652994331 You need to install Here is the error message from your cmake.
|
@tobegit3hub Thanks for the reply. I installed TensorFlow with pip and also built it from source; if I run import tensorflow as tf; tf.version I can find 1.15.0 in my env. I guess there's something wrong with the path? |
@tobegit3hub Maybe I should set the TensorFlow path in config.cmake before running cmake for TVM; sorry, I am not sure. |
@tobegit3hub It seems there are some lines about the TensorFlow path in the CMakeLists.txt of the tftvm project (which is deprecated, as you said): https://github.com/tobegit3hub/tftvm/blob/master/CMakeLists.txt Thanks |
@tobegit3hub Hi, I checked the cmake file again and I think the problem is this: the TF_TVMDSOOP.cmake file uses find_package() to find {python3_executable}, and the path it found is /usr/local/bin/python, but I am actually using an anaconda3 env and I installed TensorFlow in that env. I tried setting python3_executable to my anaconda python path, but it did not work. Could you please help me with this? Thank you! |
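When the build picks up a different interpreter than the one that holds TensorFlow, as in the last comment, a small diagnostic can help narrow it down. The sketch below is not part of TVM or tftvm; the helper name locate_tensorflow is made up for illustration. It uses only the standard library and reports which interpreter is running and whether that interpreter can find a tensorflow package, without importing it. Running it once with /usr/local/bin/python and once with the anaconda interpreter should show which of the two CMake needs to use.

```python
import importlib.util
import sys


def locate_tensorflow():
    """Report the running interpreter and where it would import TensorFlow from."""
    # find_spec only searches the import path; it does not import the package.
    spec = importlib.util.find_spec("tensorflow")
    return {
        "executable": sys.executable,
        "tensorflow_found": spec is not None,
        "tensorflow_location": spec.origin if spec is not None else None,
    }


info = locate_tensorflow()
for key, value in info.items():
    print(f"{key}: {value}")
```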
Problem
TensorFlow is one of the most popular machine learning libraries, and most developers are used to training models and running inference with TensorFlow/TensorFlow Serving. TVM is a flexible compiler for running computation efficiently on different devices. Although TensorFlow has implemented some efficient GPU operators, developers can benefit from TVM to get speedups of more than 10x as well as FPGA support. However, TensorFlow and TVM have two different code stacks and runtime APIs.
There are two ways to integrate TVM with TensorFlow. The first is tensorflow-to-tvm, which is already supported by the relay importer. Most TensorFlow operators can be “translated” to TVM operators, which is useful if we want to run the TVM stack with a model structure from another framework.
The second is tvm-to-tensorflow. This requires embedding TVM operators in the TensorFlow graph so that a TensorFlow session can run both preset operators and TVM-optimized operators. This is really helpful if we want to use TVM to optimize part of the computation graph while developers keep using the TensorFlow Python API to describe the model and TensorFlow Serving for inference. Embedding TVM in TensorFlow is the lowest-cost way to apply TVM optimization to existing models and to extend TensorFlow functionality, for example with FPGA support.
This RFC describes our design for supporting tvm-to-tensorflow with the TensorFlow custom op API, along with the implementation details.
Considerations
Developers can use the TVM stack to build operators without limitation.
Developers can use the TVM Python package to import and load TVM operators in a TensorFlow graph.
Developers can specify the output_shape/output_dtype/target_device/memory_align for TVM operators.
Proposal
The best way to extend TensorFlow functionality is to build a TensorFlow custom op for the TVM runtime. We build an operator called TVMDSOOp, which implements CPU and GPU kernels that can load any TVM dynamic library. We can run a TensorFlow graph containing this op, which invokes TVM inference on zero-copy tensor data. Here is a walk-through.
Developers can implement TVM operators with the TVM Python API. All they need to do is export the dynamic libraries to the local file system.
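As a rough sketch of that export step, the function below builds an element-wise "add one" operator and writes it to a shared library. It assumes the TVM 0.6-era Python API (tvm.placeholder, tvm.compute, tvm.create_schedule; newer releases moved these under tvm.te) and a TVM build with LLVM enabled. The helper name export_addone_library and the operator name "addone" are chosen for illustration; the file name matches the one used in the comments on this issue.

```python
def export_addone_library(path="tvm_addone_dll.so"):
    """Build an element-wise "add one" operator and export it as a shared library."""
    # Imported lazily so this file can be loaded on machines without TVM.
    import tvm

    # Symbolic length, so the compiled function is not tied to one vector size.
    n = tvm.var("n")
    a = tvm.placeholder((n,), name="a", dtype="float32")
    b = tvm.compute(a.shape, lambda i: a[i] + tvm.const(1, a.dtype), name="b")

    # Default schedule; a real operator would be tuned here.
    sched = tvm.create_schedule(b.op)

    # "llvm" targets the host CPU; "addone" is the symbol name to look up later.
    module = tvm.build(sched, [a, b], "llvm", name="addone")
    module.export_library(path)
    return path
```

Calling export_addone_library() in an environment with TVM installed should leave tvm_addone_dll.so in the working directory, ready to be loaded from TensorFlow.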
With the code in our pull request, we specify set(USE_TFOP ON) and use CMake to build TVM from scratch. This generates the tvm_dso_op.so file and provides tvm.contrib.tf_op in the Python API. We can then use TensorFlow and TVM to build a graph that contains TVM operators and run it in a TensorFlow session.
Since every TensorFlow custom op must declare its input tensors, we wrap the TVM Python API to support operators with up to 8 input tensors. Users can pass multiple TensorFlow tensors to TVMDSOOp if the TVM operator takes multiple inputs. The Python API looks the same as for a single input.
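A minimal sketch of the TensorFlow side follows, assuming TensorFlow 1.x (tf.Session, tf.placeholder) and a TVM build that provides tvm.contrib.tf_op. tf_op.Module and the output_shape/output_dtype options come from this thread; the func accessor, the operator names "addone" and "vector_add", and the file tvm_add_dll.so are assumptions for illustration and should be checked against the tf_op module your build provides.

```python
def run_addone(lib_path="tvm_addone_dll.so"):
    """Run a single-input TVM operator inside a TensorFlow 1.x session."""
    # Imported lazily so this file can be loaded without TensorFlow or TVM.
    import tensorflow as tf
    from tvm.contrib import tf_op

    module = tf_op.Module(lib_path)
    # `func` is assumed here; it wraps the TVM function as a TensorFlow op.
    addone = module.func("addone", output_shape=[4], output_dtype="float")

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            x = tf.placeholder("float32", shape=[4])
            return sess.run(addone(x), feed_dict={x: [1.0, 2.0, 3.0, 4.0]})


def run_vector_add(lib_path="tvm_add_dll.so"):
    """Same pattern with two inputs; the library and op name are hypothetical."""
    import tensorflow as tf
    from tvm.contrib import tf_op

    module = tf_op.Module(lib_path)
    vector_add = module.func("vector_add", output_shape=[4], output_dtype="float")

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            left = tf.placeholder("float32", shape=[4])
            right = tf.placeholder("float32", shape=[4])
            feed = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
            # Multiple tensors are passed positionally, as with a single input.
            return sess.run(vector_add(left, right), feed_dict=feed)
```

If the "addone" library was built as an element-wise add-one operator, run_addone() would be expected to return the input values incremented by one.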
For more examples, please refer to https://github.com/tobegit3hub/tftvm/tree/master/examples .
All TVM operators can be embedded into a TensorFlow graph with TVMDSOOp and this Python API. Because the op is zero-copy, we don't need to copy data from TensorFlow (Tensor) to TVM (DLPack), so the performance should be good.
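To illustrate what zero-copy means here, the sketch below describes an existing buffer with a DLTensor-style header and lets a consumer write through it. This is not the code of TVMDSOOp, which is implemented in C++ against the TensorFlow and TVM runtimes. The ctypes structures are modeled on DLPack's DLTensor layout, and the helper names are made up for illustration.

```python
import ctypes

KDL_CPU = 1    # DLPack device type code for CPU
KDL_FLOAT = 2  # DLPack data type code for floating point


class DLDevice(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]


class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]


class DLTensor(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]


def describe_float_buffer(buf):
    """Wrap an existing 1-D float32 ctypes array; only its address is stored."""
    shape = (ctypes.c_int64 * 1)(len(buf))
    # The caller must keep `buf` alive for as long as the descriptor is used.
    return DLTensor(
        data=ctypes.addressof(buf),
        device=DLDevice(KDL_CPU, 0),
        ndim=1,
        dtype=DLDataType(KDL_FLOAT, 32, 1),
        shape=shape,
    )


def add_one_inplace(tensor):
    """Stand-in for a compiled operator: updates the data through the pointer."""
    values = ctypes.cast(tensor.data, ctypes.POINTER(ctypes.c_float))
    for i in range(tensor.shape[0]):
        values[i] += 1.0


source = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
tensor = describe_float_buffer(source)
add_one_inplace(tensor)
print(list(source))  # [2.0, 3.0, 4.0, 5.0]
```

The original array changes because the descriptor holds a pointer to it rather than a copy, which is the property the op relies on when passing TensorFlow tensors to TVM.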