Reloading export_torchscript_with_instances model #2582
Comments
Models exported by scripting cannot be loaded in C++ until pytorch/pytorch#46944 is fixed. There are already examples of exporting the model by tracing and loading it in C++: https://github.com/facebookresearch/detectron2/tree/master/tools/deploy
The C++ example does not show how to load models created by export_torchscript_with_instances. Since I can only use export_torchscript_with_instances to convert the RPN to a C++ model, using the tracing code to run the scripted model does not work (kind of expected). When running scripted models I get the following error: terminate called after throwing an instance of 'c10::Error'
The example model provided above contains an RPN, and it works.
Ah, OK! It says "can be loaded" in the previous answer, so I got confused by that. I will wait for the fix and try to get the output from the RPN part of the Faster RCNN. Thanks :)
We have added a workaround in 22c5c01 to support reloading scripted models.
Hi, thanks for the workaround.
We don't have an official example, but I do have some working C++ code written before the above workaround existed (using a bug-fixed PyTorch build instead):
```cpp
#include <opencv2/opencv.hpp>

#include <iostream>
#include <string>

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/script.h>
#include <torchvision/vision.h>

using namespace std;

c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) {
  const int height = img.rows;
  const int width = img.cols;
  const int channels = 3;
  auto input =
      torch::from_blob(img.data, {height, width, channels}, torch::kUInt8);
  // HWC to CHW
  input = input.to(device, torch::kFloat).permute({2, 0, 1}).contiguous();
  c10::Dict dic = c10::Dict<std::string, torch::Tensor>();
  dic.insert("image", input);
  return dic;
}

int main(int argc, const char* argv[]) {
  if (argc != 3) {
    cerr << R"xx(
Usage:
  ./torchscript_traced_mask_rcnn model.ts input.jpg
)xx";
    return 1;
  }
  std::string image_file = argv[2];

  torch::autograd::AutoGradMode guard(false);
  auto module = torch::jit::load(argv[1]);
  assert(module.buffers().size() > 0);
  // Assume that the entire model is on the same device.
  // We just put input to this device.
  auto device = (*begin(module.buffers())).device();

  cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR);
  auto inputs = get_scripting_inputs(input_img, device);

  // run the network
  c10::Stack stack{
      std::make_tuple(inputs),
      c10::optional<c10::IValue>(),
      /* do_postprocessing= */ false};
  module.get_method("inference").run(stack);
  auto outputs = stack.back();
}
```

It might be useful for showing how to create inputs.
Thanks a lot. Worked :)
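For anyone building the C++ example posted above: a CMakeLists.txt along these lines might serve as a starting point. This is a sketch, not a tested build file; the source file name is an assumption, and it presumes libtorch, torchvision's C++ ops (which provide vision.h), and OpenCV are installed where find_package can locate them.

```cmake
cmake_minimum_required(VERSION 3.12)
project(torchscript_mask_rcnn)

find_package(Torch REQUIRED)        # libtorch
find_package(TorchVision REQUIRED)  # torchvision C++ ops (torchvision/vision.h)
find_package(OpenCV REQUIRED)

add_executable(torchscript_traced_mask_rcnn torchscript_traced_mask_rcnn.cpp)
target_link_libraries(torchscript_traced_mask_rcnn
    ${TORCH_LIBRARIES} TorchVision::TorchVision ${OpenCV_LIBS})
set_property(TARGET torchscript_traced_mask_rcnn PROPERTY CXX_STANDARD 14)
```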
I am trying to get a C++ model for FasterRCNN, as well as for just its RPN part.
Using export_torchscript_with_instances, I am able to get the output .pt models for both.
Could you please provide steps to reload these models in C++? Alternatively, is there any way to convert the RPN model through tracing/scripting?