
Reloading export_torchscript_with_instances model #2582

Closed
kg512 opened this issue Feb 3, 2021 · 9 comments
Labels
upstream issues Issues in other libraries

Comments

kg512 commented Feb 3, 2021

I am trying to get a C++ model for FasterRCNN and the RPN part of it.

Using export_torchscript_with_instances, I am able to get the output .pt models for both.

Could you please provide steps to reload this model in C++? Or alternatively, is there any way to convert the RPN model through tracing/scripting?

@kg512 kg512 added the enhancement Improvements or good new features label Feb 3, 2021
ppwwyyxx commented Feb 3, 2021

Models exported by scripting cannot be loaded in C++ until pytorch/pytorch#46944 is fixed.

There are already examples to export the model by tracing and load it in C++: https://github.com/facebookresearch/detectron2/tree/master/tools/deploy

@ppwwyyxx ppwwyyxx closed this as completed Feb 3, 2021
@ppwwyyxx ppwwyyxx added upstream issues Issues in other libraries and removed enhancement Improvements or good new features labels Feb 3, 2021
kg512 commented Feb 3, 2021

The C++ example does not provide an example to load models created by export_torchscript_with_instances.

Since I can only use export_torchscript_with_instances to convert the RPN to a C++ model, using the tracing example code to run the scripted model does not work (somewhat expected).

I get the following error on running scripted models:

terminate called after throwing an instance of 'c10::Error'
what(): forward() Expected a value of type 'Tuple[Dict[str, Tensor]]' for argument 'batched_inputs' but instead found type 'Tensor'.
Position: 1
Declaration: forward(torch.detectron2.modeling.meta_arch.rcnn.GeneralizedRCNN self, (Dict(str, Tensor)) batched_inputs) -> (torch.detectron2.export.torchscript_patch1.ScriptedInstances1[])
Exception raised from checkArg at /opt/conda/conda-bld/pytorch_1612253278703/work/aten/src/ATen/core/function_schema_inl.h:162 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f4a5019a8d2 in /home/kratika/anaconda3/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7f4a50197f1b in /home/kratika/anaconda3/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #2: + 0xaf473f (0x7f4a4840873f in /home/kratika/anaconda3/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::jit::GraphFunction::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::string, c10::IValue, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&) + 0x2d (0x7f4a4a85a1bd in /home/kratika/anaconda3/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::jit::Method::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::string, c10::IValue, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&) + 0x138 (0x7f4a4a8674d8 in /home/kratika/anaconda3/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: ./torchscript_traced_mask_rcnn() [0x430999]
frame #6: ./torchscript_traced_mask_rcnn() [0x427b30]
frame #7: __libc_start_main + 0xf0 (0x7f4a088a6840 in /lib/x86_64-linux-gnu/libc.so.6)
frame #8: ./torchscript_traced_mask_rcnn() [0x426b89]
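For context, the error above says the scripted forward() declares batched_inputs as Tuple[Dict[str, Tensor]], while the tracing example passes a bare tensor. A minimal sketch of the expected structure (the image shape below is an arbitrary placeholder):

```python
import torch

# The traced deploy example feeds the model a bare CHW image tensor...
image = torch.zeros(3, 480, 640, dtype=torch.uint8)  # placeholder size

# ...but a model exported by scripting declares
#   forward(self, batched_inputs: Tuple[Dict[str, Tensor]])
# so each image must be wrapped in a dict, and the batch in a tuple:
batched_inputs = ({"image": image},)

assert isinstance(batched_inputs, tuple)
assert tuple(batched_inputs[0]["image"].shape) == (3, 480, 640)
```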

ppwwyyxx commented Feb 3, 2021

export_torchscript_with_instances uses scripting, so, as I said above, it cannot be loaded in C++ for now.

ppwwyyxx commented Feb 3, 2021

using the tracing code to run the scripted model does not work (kind of expected)

The example model provided above contains an RPN and it works.

kg512 commented Feb 3, 2021

Ah ok! It says "can be loaded" in the previous answer, got confused by that! Will wait for the fix.

Will try to get the output from the RPN part of the Faster R-CNN. Thanks :)

@ppwwyyxx

We have added a workaround in 22c5c01 to support reloading scripted models.

kg512 commented Apr 7, 2021

Hi,

Thanks for the workaround.
I am still having trouble loading the inputs in C++. It expects Tuple[Dict[str, Tensor]] as input. Is there any example code on how to transform the input to load into the scripted model?

ppwwyyxx commented Apr 7, 2021

We don't have an official example, but I do have some working C++ code written before the above workaround existed (it instead relied on a PyTorch build with the bug fixed):


#include <opencv2/opencv.hpp>
#include <iostream>
#include <string>

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/script.h>

#include <torchvision/vision.h>

using namespace std;

c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) {
  const int height = img.rows;
  const int width = img.cols;
  const int channels = 3;

  auto input =
      torch::from_blob(img.data, {height, width, channels}, torch::kUInt8);
  // HWC to CHW
  input = input.to(device, torch::kFloat).permute({2, 0, 1}).contiguous();

  // Pack the image into the Dict[str, Tensor] entry the scripted model expects.
  auto dic = c10::Dict<std::string, torch::Tensor>();
  dic.insert("image", input);
  return dic;
}

int main(int argc, const char* argv[]) {
  if (argc != 3) {
    cerr << R"xx(
Usage:
   ./torchscript_traced_mask_rcnn model.ts input.jpg
)xx";
    return 1;
  }
  std::string image_file = argv[2];

  torch::autograd::AutoGradMode guard(false);
  auto module = torch::jit::load(argv[1]);

  assert(module.buffers().size() > 0);
  // Assume that the entire model is on the same device.
  // We just put input to this device.
  auto device = (*begin(module.buffers())).device();

  cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR);
  auto inputs = get_scripting_inputs(input_img, device);

  // run the network
  c10::Stack stack{
    std::make_tuple(inputs),
      c10::optional<c10::IValue>(),
      /* do_postprocessing= */ false};
  module.get_method("inference").run(stack);
  auto outputs = stack.back();
}

It might be useful for showing how to create inputs.
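Before wiring up the C++ side, the same calling convention can be sanity-checked from Python with a toy scripted module. This is a hypothetical stand-in that only mimics the forward signature, not the detectron2 model itself:

```python
import torch
from typing import Dict, Tuple

class Toy(torch.nn.Module):
    # Same argument type as the exported model's forward():
    # a one-element tuple of {"image": Tensor} dicts.
    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]]) -> torch.Tensor:
        img = batched_inputs[0]["image"]
        return img.float().mean()

ts = torch.jit.script(Toy())
out = ts(({"image": torch.ones(3, 4, 4, dtype=torch.uint8)},))
# mean of an all-ones tensor is 1.0
```

Passing a bare tensor to `ts` here raises the same "Expected a value of type 'Tuple[Dict[str, Tensor]]'" error shown earlier in this thread.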

kg512 commented Apr 7, 2021

Thanks a lot. Worked :)

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Feb 2, 2022