Export model for inference in C++ #104

Open
AlexejD opened this issue Jun 8, 2018 · 0 comments

AlexejD commented Jun 8, 2018

Hi all,

First of all, thanks for sharing so much content on your model.

I want to run squeezeDet inference inside a C++ project (roscpp). To do this, I load the model from the checkpoint and model data files with the TensorFlow C++ API function "ReadBinaryProto", then create a TensorFlow session and run it to perform inference.

This works fine with the model checkpoint you provided for the demo, using these tensors:
Input tensor: image_input
Output tensors: bbox/trimming/bbox, probability/score, probability/class_idx
(I also had to feed a value for the keep_prob tensor to make it work.)

However, I want a batch size of 1 for inference. So I modified the provided demo.py to set the batch size to 1 and, after restoring the model, to save it to disk again with the TensorFlow Saver.
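With BATCH_SIZE = 1, the tensor fed from C++ has to match what image_demo() feeds in Python: a float32 BGR image with the per-channel means subtracted and a leading batch axis, i.e. shape [1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]. The sketch below reproduces that preprocessing so its output can serve as a reference when checking the C++ input buffer. The helper name make_input_batch is hypothetical, the resize step is assumed to have been done already, and the mean values in the usage line are illustrative; take the real ones from mc.BGR_MEANS.

```python
import numpy as np


def make_input_batch(image_bgr, bgr_means):
    """Turn one BGR image of shape (H, W, 3) into a batch of size 1.

    Mirrors image_demo(): cast to float32, subtract the per-channel means,
    then add a leading batch axis (BATCH_SIZE = 1).
    Assumes the image is already resized to (IMAGE_WIDTH, IMAGE_HEIGHT).
    """
    image = np.asarray(image_bgr, dtype=np.float32)
    means = np.asarray(bgr_means, dtype=np.float32).reshape(1, 1, 3)
    return (image - means)[np.newaxis, ...]  # shape (1, H, W, 3)


if __name__ == "__main__":
    # Illustrative all-zero 2x3 image and example means (assumed values);
    # in the real pipeline use the resized image and mc.BGR_MEANS.
    dummy = np.zeros((2, 3, 3), dtype=np.uint8)
    batch = make_input_batch(dummy, [103.939, 116.779, 123.68])
    print(batch.shape, batch.dtype)  # (1, 2, 3, 3) float32
```

On the C++ side the same values would go, in row-major NHWC order, into a tensorflow::Tensor of DT_FLOAT with that shape.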

After saving, I can load the modified model in my C++ program and build it. At runtime there are no errors, but when inference is called, execution "gets stuck" forever without raising any error.

So my questions are:
@BichenWuUCB How did you save the model you provide for use with demo.py? Why doesn't my approach of simply using the TensorFlow Saver work?

Here is my modified image_demo() function from demo.py:

```python
# Modified image_demo() from demo.py. It relies on that file's imports
# (cv2, glob, os, numpy as np, tensorflow as tf, config, nets) and on the
# FLAGS and _draw_box names available there.
def image_demo():
  """Detect image."""

  assert FLAGS.demo_net == 'squeezeDet' or FLAGS.demo_net == 'squeezeDet+', \
      'Selected neural net architecture not supported: {}'.format(FLAGS.demo_net)

  with tf.Graph().as_default():
    # Load model
    if FLAGS.demo_net == 'squeezeDet':
      mc = kitti_squeezeDet_config()
      # set batch size to 1 for inference
      mc.BATCH_SIZE = 1
      # model parameters will be restored from checkpoint
      mc.LOAD_PRETRAINED_MODEL = False
      model = SqueezeDet(mc, FLAGS.gpu)
    elif FLAGS.demo_net == 'squeezeDet+':
      mc = kitti_squeezeDetPlus_config()
      mc.BATCH_SIZE = 1
      mc.LOAD_PRETRAINED_MODEL = False
      model = SqueezeDetPlus(mc, FLAGS.gpu)

    saver = tf.train.Saver(model.model_params)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
      # Restore the trained weights from the provided checkpoint
      saver.restore(sess, FLAGS.checkpoint)

      # Save graph metadata and checkpoint
      saver.save(sess, './data/out/checkpoint.ckpt')

      for f in glob.iglob(FLAGS.input_path):
        im = cv2.imread(f)
        im = im.astype(np.float32, copy=False)
        im = cv2.resize(im, (mc.IMAGE_WIDTH, mc.IMAGE_HEIGHT))
        input_image = im - mc.BGR_MEANS

        # Detect
        det_boxes, det_probs, det_class = sess.run(
            [model.det_boxes, model.det_probs, model.det_class],
            feed_dict={model.image_input:[input_image]})

        # Filter
        final_boxes, final_probs, final_class = model.filter_prediction(
            det_boxes[0], det_probs[0], det_class[0])

        keep_idx    = [idx for idx in range(len(final_probs)) \
                          if final_probs[idx] > mc.PLOT_PROB_THRESH]
        final_boxes = [final_boxes[idx] for idx in keep_idx]
        final_probs = [final_probs[idx] for idx in keep_idx]
        final_class = [final_class[idx] for idx in keep_idx]

        # TODO(bichen): move this color dict to configuration file
        cls2clr = {
            'car': (255, 191, 0),
            'cyclist': (0, 191, 255),
            'pedestrian':(255, 0, 191)
        }

        # Draw boxes
        _draw_box(
            im, final_boxes,
            [mc.CLASS_NAMES[idx]+': (%.2f)'% prob \
                for idx, prob in zip(final_class, final_probs)],
            cdict=cls2clr,
        )

        file_name = os.path.split(f)[1]
        out_file_name = os.path.join(FLAGS.out_dir, 'out_'+file_name)
        cv2.imwrite(out_file_name, im)
        print('Image detection output saved to {}'.format(out_file_name))
```
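One common way to hand a TF1-style model to the C++ API, not necessarily the way the demo model was exported, is to freeze it. You restore the variables, fold them into constants, keep only the nodes the outputs need, and write a single binary GraphDef that ReadBinaryProto can parse. The C++ side then needs no separate checkpoint restore. The sketch below shows this on a tiny hypothetical graph through tf.compat.v1; build_toy_graph, freeze and export are illustrative names. For squeezeDet you would call saver.restore(sess, FLAGS.checkpoint) instead of the initializer and pass the real output names (bbox/trimming/bbox, probability/score, probability/class_idx).

```python
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, as used by squeezeDet


def build_toy_graph():
    """Hypothetical stand-in graph: a placeholder, one variable, and an
    output whose name mimics 'probability/score'."""
    graph = tf.Graph()
    with graph.as_default():
        image = tf1.placeholder(tf.float32, [1, 4], name="image_input")
        weights = tf1.get_variable("weights", initializer=tf.ones([4, 2]))
        logits = tf.matmul(image, weights)
        with tf1.variable_scope("probability"):
            tf.identity(tf.reduce_max(logits, axis=1), name="score")
    return graph


def freeze(graph, output_names):
    """Replace variables with constants holding their current values and
    prune everything not needed to compute `output_names`."""
    with tf1.Session(graph=graph) as sess:
        # Stand-in for saver.restore(sess, FLAGS.checkpoint) in the real model.
        sess.run(tf1.global_variables_initializer())
        return tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_names)


def export(graph_def, out_dir, name="frozen.pb"):
    """Write a binary (not text) GraphDef and return its path."""
    return tf.io.write_graph(graph_def, out_dir, name, as_text=False)


if __name__ == "__main__":
    frozen = freeze(build_toy_graph(), ["probability/score"])
    path = export(frozen, tempfile.mkdtemp())
    print("wrote", path)
```

Freezing only keeps what the listed outputs depend on, so it is worth checking which tensor the network actually reads its image from. If, in the saved graph, model.image_input is a tensor produced from an input queue rather than the image_input placeholder itself, feeding the placeholder by name from C++ does not bypass the queue. A dequeue on an empty queue blocks, which could explain a Session::Run that never returns. Printing model.image_input.name in demo.py shows which tensor the Python feed_dict overrides.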