Merge pull request #1 from project-zetton/add_tracker
Add SORT tracker
corenel authored May 8, 2021
2 parents 460cbcc + 3720738 commit 5f939f9
Showing 27 changed files with 3,362 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -14,6 +14,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/util.cmake)

# find dependencies
find_package(Threads)
find_package(Eigen3 REQUIRED)
find_package(OpenCV 4 REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/yolo-tensorrt)

@@ -41,6 +42,7 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/external/yolo-tensorrt/include
${catkin_INCLUDE_DIRS}
${OpenCV_INCLUDE_DIRS}
${EIGEN3_INCLUDE_DIR}
)

# find all source files
@@ -61,6 +63,7 @@ target_link_libraries(
${catkin_LIBRARIES}
${OpenCV_LIBS}
Threads::Threads
Eigen3::Eigen
yolo_trt
)

18 changes: 18 additions & 0 deletions README.md
@@ -38,6 +38,24 @@ Object detection powered by YOLO-family algorithms.
rosrun zetton_inference example_rtsp_yolo_object_detector
```

### Object Tracking

#### SORT

Object tracking powered by the SORT algorithm.

- Receive images from a ROS topic, then run detection & tracking:

```bash
rosrun zetton_inference example_ros_mot_tracker
```

#### MOT

Object tracking powered by Optical Flow & ReID.

(W.I.P)

## License

- For academic use, this project is licensed under the 2-clause BSD License - see the [LICENSE file](LICENSE) for details.
2 changes: 2 additions & 0 deletions asset/.gitignore
@@ -1,3 +1,5 @@
*.weights
*.engine
*.pth
*.cfg
*.mp4
45 changes: 45 additions & 0 deletions config/object_tracker.yml
@@ -0,0 +1,45 @@
zetton_inference:
  mot_tracker:
    tracker:
      track_fail_timeout_tick: 5
      bbox_overlap_ratio: 0.6
      detector_update_timeout_tick: 30
      detector_bbox_padding: 10
      reid_match_threshold: 3.0
      reid_match_bbox_dis: 80
      reid_match_bbox_size_diff: 80
      stop_opt_timeout: 2

    local_database:
      height_width_ratio_min: 0.85
      height_width_ratio_max: 4.0
      record_interval: 0.1
      feature_smooth_ratio: 0.7

    kalman_filter:
      q_xy: 100
      q_wh: 25
      p_xy_pos: 100
      p_xy_dp: 10000
      p_wh_size: 25
      p_wh_ds: 25
      r_theta: 0.08 # 0.02 rad ≈ 1 degree
      r_f: 0.04
      r_tx: 4
      r_ty: 4
      residual_threshold: 16

    optical_flow:
      min_keypoints_to_track: 10
      keypoints_num_factor_area: 8000
      corner_detector_max_num: 1000
      corner_detector_quality_level: 0.06
      corner_detector_min_distance: 1
      corner_detector_block_size: 3
      corner_detector_use_harris: false
      corner_detector_k: 0.04
      min_keypoints_to_cal_H_mat: 10
      min_keypoints_for_motion_estimation: 50
      min_pixel_dis_square_for_scene_point: 2
      use_resize: true
      resize_factor: 2
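
The `kalman_filter` keys set the noise levels of the tracker's motion model, which is where the Eigen dependency added in CMakeLists.txt above comes in. Purely as an illustration of how such values are usually consumed, the sketch below builds a constant-velocity filter over the bounding-box center; the state layout and the mapping of `q_xy`, `p_xy_pos` and `p_xy_dp` onto the Q and P matrices are assumptions made for this sketch, not this repository's actual implementation.

```cpp
#include <Eigen/Dense>

// Minimal constant-velocity filter over a bounding-box center.
// State x = [cx, cy, vx, vy].
struct CenterKalmanSketch {
  Eigen::Vector4d x = Eigen::Vector4d::Zero();  // state estimate
  Eigen::Matrix4d P = Eigen::Matrix4d::Zero();  // state covariance
  Eigen::Matrix4d Q = Eigen::Matrix4d::Zero();  // process noise

  // q_xy, p_xy_pos and p_xy_dp correspond to the identically named keys in
  // config/object_tracker.yml (mapping assumed for illustration).
  CenterKalmanSketch(double q_xy, double p_xy_pos, double p_xy_dp) {
    Q.diagonal() << q_xy, q_xy, q_xy, q_xy;
    P.diagonal() << p_xy_pos, p_xy_pos, p_xy_dp, p_xy_dp;
  }

  // Propagate the center and its covariance forward by dt seconds.
  void Predict(double dt) {
    Eigen::Matrix4d F = Eigen::Matrix4d::Identity();
    F(0, 2) = dt;  // cx += vx * dt
    F(1, 3) = dt;  // cy += vy * dt
    x = F * x;
    P = F * P * F.transpose() + Q;
  }
};
```

By the same reading, `r_theta`, `r_f`, `r_tx` and `r_ty` would populate the measurement-noise matrix for a rotation/scale/translation measurement coming from the optical-flow stage, and `residual_threshold` would reject updates whose innovation is too large; both points are inferred from the key names rather than from code in this PR.
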
3 changes: 2 additions & 1 deletion example/ros_image_publisher.py
@@ -54,6 +54,7 @@ def read_image(self):
        ret, self.image = self.source.read()
        if not ret:
            self.source.set(cv2.CAP_PROP_POS_FRAMES, 0)
            print('shift back to the beginning')

    def publish(self):
        try:
@@ -65,7 +66,7 @@

    def start(self):
        while not rospy.is_shutdown():
            rospy.loginfo('publishing image')
            # rospy.loginfo('publishing image')
            if self.image is not None:
                self.publish()
            self.loop_rate.sleep()
131 changes: 131 additions & 0 deletions example/ros_mot_tracker.cc
@@ -0,0 +1,131 @@
#include <cv_bridge/cv_bridge.h>
#include <image_transport/image_transport.h>
#include <image_transport/publisher.h>
#include <ros/package.h>
#include <ros/ros.h>
#include <ros/time.h>
#include <sensor_msgs/image_encodings.h>

#include <csignal>
#include <opencv2/opencv.hpp>
#include <string>

#include "zetton_common/util/ros_util.h"
#include "zetton_inference/detector/yolo_object_detector.h"
#include "zetton_inference/tracker/mot_tracker.h"
#include "zetton_inference/tracker/sort_tracker.h"

void signalHandler(int sig) {
ROS_WARN("Trying to exit!");
ros::shutdown();
}

class RosMotTracker {
private:
inline void RosImageCallback(const sensor_msgs::ImageConstPtr& msg) {
// convert image msg to cv::Mat
cv_bridge::CvImagePtr cv_ptr;
try {
cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
} catch (cv_bridge::Exception& e) {
ROS_ERROR("cv_bridge exception: %s", e.what());
return;
}

// do detection
zetton::inference::ObjectDetectionResults detections;
detector_.Detect(cv_ptr->image, detections);

// do tracking
tracker_.Track(cv_ptr->image, ros::Time::now(), detections);

// print detections and tracks
ROS_INFO("Detections:");
for (auto& detection : detections) {
ROS_INFO_STREAM(detection);
detection.Draw(cv_ptr->image);
}
ROS_INFO("Trackings:");
for (auto& track : tracker_.tracks()) {
// if (track.tracking_fail_count <= 3) {
ROS_INFO_STREAM(track);
track.Draw(cv_ptr->image);
// }
}

// publish results
image_pub_.publish(
cv_bridge::CvImage(std_msgs::Header(), "bgr8", cv_ptr->image)
.toImageMsg());
ROS_INFO("---");
}

ros::NodeHandle* nh_;

image_transport::ImageTransport it_;
image_transport::Subscriber image_sub_;
image_transport::Publisher image_pub_;

zetton::inference::YoloObjectDetector detector_;
// zetton::inference::MotTracker tracker_;
zetton::inference::SortTracker tracker_;

public:
RosMotTracker(ros::NodeHandle* nh) : nh_(nh), it_(*nh_) {
// load params
// hardcoded or using GPARAM
// std::string image_topic_sub = "/uvds_communication/image_streaming/mavic_0";
std::string image_topic_sub = "/camera/image";

// subscribe to input video feed
image_sub_ = it_.subscribe(image_topic_sub, 1,
&RosMotTracker::RosImageCallback, this);
// publish images
image_pub_ = it_.advertise("/camera/result", 1);

// prepare yolo config
yolo_trt::Config config_v4;
std::string package_path = ros::package::getPath("zetton_inference");
// config_v4.net_type = yolo_trt::ModelType::YOLOV4_TINY;
// config_v4.file_model_cfg = package_path + "/asset/yolov4-tiny-uav.cfg";
// config_v4.file_model_weights =
// package_path + "/asset/yolov4-tiny-uav_best.weights";
// config_v4.net_type = yolo_trt::ModelType::YOLOV4;
// config_v4.file_model_cfg = package_path + "/asset/yolov4-608.cfg";
// config_v4.file_model_weights =
// package_path + "/asset/yolov4-608.weights";
config_v4.net_type = yolo_trt::ModelType::YOLOV4;
config_v4.file_model_cfg = package_path + "/asset/yolov4-visdrone.cfg";
config_v4.file_model_weights =
package_path + "/asset/yolov4-visdrone-best.weights";
config_v4.inference_precision = yolo_trt::Precision::FP32;
config_v4.detect_thresh = 0.4;

// initialize detector
detector_.Init(config_v4);
detector_.SetWidthLimitation(50, 1920);
detector_.SetHeightLimitation(50, 1920);

// initialize tracker
tracker_.Init();
}

~RosMotTracker() {
if (nh_) delete nh_;
}
};

int main(int argc, char** argv) {
// init node
ros::init(argc, argv, "example_ros_mot_tracker");
auto nh = new ros::NodeHandle("~");

// catch external interrupt initiated by the user and exit program
signal(SIGINT, signalHandler);

// init instance
RosMotTracker pipeline(nh);

ros::spin();
return 0;
}
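
For context on what `SortTracker::Track()` has to do each frame: SORT predicts every existing track forward with a Kalman filter and then associates the predicted boxes with the incoming detections by intersection-over-union. The sketch below shows only that association step; the helper names are hypothetical, and a greedy loop stands in for whatever matching strategy (e.g. the Hungarian algorithm) the tracker in this PR actually uses.

```cpp
#include <opencv2/core.hpp>

#include <utility>
#include <vector>

// Intersection-over-union of two axis-aligned boxes.
static double IoU(const cv::Rect2d& a, const cv::Rect2d& b) {
  const double inter = (a & b).area();
  const double uni = a.area() + b.area() - inter;
  return uni > 0.0 ? inter / uni : 0.0;
}

// Greedily pair predicted track boxes with detection boxes whose IoU
// exceeds min_iou. Returns (track index, detection index) pairs.
static std::vector<std::pair<int, int>> GreedyAssociate(
    const std::vector<cv::Rect2d>& predicted_tracks,
    const std::vector<cv::Rect2d>& detections, double min_iou) {
  std::vector<std::pair<int, int>> matches;
  std::vector<bool> used(detections.size(), false);
  for (int t = 0; t < static_cast<int>(predicted_tracks.size()); ++t) {
    int best_d = -1;
    double best_iou = min_iou;
    for (int d = 0; d < static_cast<int>(detections.size()); ++d) {
      if (used[d]) continue;
      const double iou = IoU(predicted_tracks[t], detections[d]);
      if (iou > best_iou) {
        best_iou = iou;
        best_d = d;
      }
    }
    if (best_d >= 0) {
      used[best_d] = true;
      matches.emplace_back(t, best_d);
    }
  }
  return matches;
}
```

In the standard SORT scheme, matched pairs update their Kalman filters, unmatched detections spawn new tracks, and tracks that stay unmatched for too long are dropped.
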
25 changes: 25 additions & 0 deletions include/zetton_inference/interface/base_object_tracker.h
@@ -0,0 +1,25 @@
#pragma once

#include <iostream>
#include <utility>
#include <vector>

#include "opencv2/opencv.hpp"
#include "zetton_common/util/registerer.h"
#include "zetton_inference/interface/base_inference.h"

namespace zetton {
namespace inference {

class BaseObjectTracker : public BaseInference {
public:
void Infer() override = 0;
virtual bool Track() = 0;

ZETTON_REGISTER_REGISTERER(BaseObjectTracker);
#define ZETTON_REGISTER_OBJECT_TRACKER(name) \
ZETTON_REGISTER_CLASS(BaseObjectTracker, name)
};

} // namespace inference
} // namespace zetton
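
As a usage note for the registry above: a concrete tracker derives from `BaseObjectTracker`, implements the two virtual methods, and invokes the registration macro. The class below is a made-up illustration only, and it assumes `BaseInference` declares no further pure-virtual methods beyond the `Infer()` visible here.

```cpp
#include "zetton_inference/interface/base_object_tracker.h"

namespace zetton {
namespace inference {

// Hypothetical tracker used only to illustrate the registration macro.
class DemoTracker : public BaseObjectTracker {
 public:
  void Infer() override { Track(); }

  bool Track() override {
    // A real tracker would predict existing tracks and associate them
    // with fresh detections here.
    return true;
  }
};

// Makes DemoTracker constructible by name via the BaseObjectTracker registry.
ZETTON_REGISTER_OBJECT_TRACKER(DemoTracker);

}  // namespace inference
}  // namespace zetton
```
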