Sometimes we need to convert a PyTorch model to the TFLite format so that it can be deployed on devices. The existing conversion method usually involves the following steps:
- Convert to an ONNX model via torch.onnx.export
- Convert to a TensorFlow frozen model via onnx2tensorflow
- Convert to a TFLite model via tensorflow.lite.TFLiteConverter
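The three-step route above can be sketched roughly as follows. This is an illustration only, assuming the `onnx`, `onnx-tf` (`onnx_tf`) and `tensorflow` packages are installed; the function name `convert_via_onnx`, the file paths and `opset_version=11` are hypothetical choices. Note that recent onnx-tf releases export a SavedModel directory rather than a frozen `.pb` graph, so the last step here uses `from_saved_model`.

```python
# Hypothetical helper sketching the traditional PyTorch -> ONNX -> TF -> TFLite route.
# Heavy imports live inside the function so this module loads without them.


def convert_via_onnx(model, dummy_input, onnx_path="model.onnx",
                     tf_dir="model_tf", tflite_path="model.tflite"):
    import onnx
    import tensorflow as tf
    import torch
    from onnx_tf.backend import prepare

    model.eval()

    # Step 1: PyTorch -> ONNX (opset 11 is an arbitrary choice for this sketch)
    torch.onnx.export(model, dummy_input, onnx_path, opset_version=11)

    # Step 2: ONNX -> TensorFlow (recent onnx-tf versions write a SavedModel dir)
    prepare(onnx.load(onnx_path)).export_graph(tf_dir)

    # Step 3: TensorFlow -> TFLite
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_dir)
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    return tflite_path
```

Each hop goes through a different toolchain with its own operator coverage, which is where the problems listed next tend to come from.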
This approach has the following shortcomings:
- The conversion chain is long and often leads to problems
- Quantized models cannot be converted
- Models with LSTM cannot be converted
- Models converted with onnx2tf contain many redundant ops
To solve these problems, we implemented this converter, which translates models directly from PyTorch to TFLite. Its main features are:
- Support for PyTorch 1.6+
- Support for quantized models
- Support for the LSTM op
- Many optimization passes, including elimination of consecutive transposes and reshapes, no-op removal, etc.
- Written in 100% Python, so it is easy to maintain
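To illustrate what an optimization pass such as consecutive-transpose elimination does, here is a minimal sketch on a simplified model: a linear chain of `(op_type, attribute)` tuples instead of a real computation graph. The names `compose_perms` and `optimize_chain` are hypothetical and this is not the converter's own `optimize.py`; it only shows the idea of fusing back-to-back TRANSPOSE and RESHAPE ops and dropping identity transposes.

```python
from typing import List, Tuple

# Simplified op representation: (op_type, attribute),
# e.g. ("TRANSPOSE", (0, 2, 3, 1)) or ("RESHAPE", (1, 3136))
Op = Tuple[str, tuple]


def compose_perms(first: tuple, second: tuple) -> tuple:
    """Permutation equivalent to transposing by `first`, then by `second`."""
    # Output axis i of the second transpose reads axis second[i] of the
    # intermediate tensor, which reads axis first[second[i]] of the input.
    return tuple(first[i] for i in second)


def optimize_chain(ops: List[Op]) -> List[Op]:
    """Fuse consecutive TRANSPOSE/RESHAPE ops and drop identity transposes."""
    result: List[Op] = []
    for op_type, attr in ops:
        if result and result[-1][0] == op_type == "TRANSPOSE":
            # Two transposes in a row collapse into a single transpose
            _, prev = result.pop()
            attr = compose_perms(prev, attr)
        elif result and result[-1][0] == op_type == "RESHAPE":
            # Only the target shape of the last reshape matters
            result.pop()
        if op_type == "TRANSPOSE" and attr == tuple(range(len(attr))):
            continue  # an identity permutation is a no-op
        result.append((op_type, attr))
    return result


# NCHW -> NHWC followed by NHWC -> NCHW cancels out; the two reshapes fuse.
chain = [
    ("TRANSPOSE", (0, 2, 3, 1)),
    ("TRANSPOSE", (0, 3, 1, 2)),
    ("RESHAPE", (1, 64, 49)),
    ("RESHAPE", (1, 3136)),
]
print(optimize_chain(chain))  # [('RESHAPE', (1, 3136))]
```

Layout round trips like NCHW/NHWC transposes are a common source of the redundant ops mentioned above, which is why such a pass is worth having.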
The converter's source code is organized as follows.
- operators: Most of the components of the converter
  - tflite: TFLite-related classes
    - base.py: TFLite base data structures
    - custom.py: TFLite custom operators
    - generated_ops.py: Wrapper classes generated from the TFLite schema
    - transformable.py: Transformable operators, such as BatchNorm, Conv2d, and other composite operators composed of multiple TFLite operators
  - torch: PyTorch-related classes
    - base.py: The base data structures needed for TorchScript parsing
    - aten.py: Translation of ATen operators
    - quantized.py: Translation of quantized operators
  - base.py: Definition of generic operators
  - graph.py: Computation graph infrastructure
  - op_version.py: Handler for operator versions
  - optimize.py: Computation graph optimization
- schemas: Most of the schemas used by the converter
  - tflite: TFLite-related schemas
    - schema_generated.py: TFLite schema parsers
  - torch: PyTorch-related schemas
    - aten_schema.py: Wrapper classes generated from the ATen schema
    - quantized_schema.py: Wrapper classes generated from the quantized schema
    - torchvision_schema.py: Wrapper classes generated from the Torchvision schema
- base.py: Entry class TFLiteConverter
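Converting a model through the entry class might look like the sketch below. The import path `tinynn.converter` and the constructor order `(model, dummy_input, tflite_path)` are assumptions based on the TinyNeuralNetwork project that ships this converter; check `base.py` for the exact signature. The wrapper name `convert_directly` and the default input shape are hypothetical.

```python
# Hypothetical wrapper around the converter's entry class.
# Imports are inside the function so this module loads without torch installed.


def convert_directly(model, input_shape=(1, 3, 224, 224), tflite_path="model.tflite"):
    import torch
    # Assumption: the entry class is importable from tinynn.converter
    from tinynn.converter import TFLiteConverter

    model.eval()
    # The converter traces the model with a dummy input of the expected shape
    dummy_input = torch.rand(input_shape)
    converter = TFLiteConverter(model, dummy_input, tflite_path)
    converter.convert()  # writes the .tflite file to tflite_path
    return tflite_path
```

Compared with the ONNX route, there is a single step and no intermediate files.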
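As an example of a transformable operator, TFLite has no BatchNorm builtin, so an inference-mode BatchNorm can be lowered to a per-channel MUL followed by an ADD. The sketch below shows only that arithmetic, with hypothetical function names; it is not the code in `transformable.py`, which may instead fold BatchNorm into a preceding convolution. It uses the identity BN(x) = gamma * (x - mean) / sqrt(var + eps) + beta = x * scale + shift, with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale.

```python
import math


def batchnorm_to_mul_add(gamma, beta, mean, var, eps=1e-5):
    """Per-channel (scale, shift) such that BN(x) == x * scale + shift."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    shift = [b - m * s for b, m, s in zip(beta, mean, scale)]
    return scale, shift


def batchnorm_reference(x, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm, one value per channel, for comparison."""
    return [g * (xi - m) / math.sqrt(v + eps) + b
            for xi, g, b, m, v in zip(x, gamma, beta, mean, var)]


gamma, beta, mean, var = [1.0, 2.0], [0.0, 1.0], [0.5, -1.0], [4.0, 0.25]
scale, shift = batchnorm_to_mul_add(gamma, beta, mean, var, eps=0.0)
x = [1.0, 0.0]
print([xi * s + t for xi, s, t in zip(x, scale, shift)])     # [0.25, 5.0]
print(batchnorm_reference(x, gamma, beta, mean, var, eps=0.0))  # [0.25, 5.0]
```

Because scale and shift are constants at conversion time, the composite op costs two elementwise TFLite ops.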