Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ROS2 for heart anomaly detection #337

Merged
merged 8 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-_
# Copyright 2020-2022 OpenDR European Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import torch

import rclpy
from rclpy.node import Node
from vision_msgs.msg import Classification2D
from std_msgs.msg import Float32MultiArray

from opendr_ros2_bridge import ROS2Bridge
from opendr.perception.heart_anomaly_detection import GatedRecurrentUnitLearner, AttentionNeuralBagOfFeatureLearner


class HeartAnomalyNode(Node):

def __init__(self, input_ecg_topic="/ecg/ecg", output_heart_anomaly_topic="/opendr/heart_anomaly",
device="cuda", model="anbof"):
"""
Creates a ROS2 Node for heart anomaly (atrial fibrillation) detection from ecg data
:param input_ecg_topic: Topic from which we are reading the input array data
:type input_ecg_topic: str
:param output_heart_anomaly_topic: Topic to which we are publishing the predicted class
:type output_heart_anomaly_topic: str
:param device: device on which we are running inference ('cpu' or 'cuda')
:type device: str
:param model: model to use: anbof or gru
:type model: str
"""
super().__init__("heart_anomaly_detection_node")

self.publisher = self.create_publisher(Classification2D, output_heart_anomaly_topic, 1)

self.subscriber = self.create_subscription(Float32MultiArray, input_ecg_topic, self.callback, 1)

self.bridge = ROS2Bridge()

# AF dataset
self.channels = 1
self.series_length = 9000

if model == 'gru':
self.learner = GatedRecurrentUnitLearner(in_channels=self.channels, series_length=self.series_length,
n_class=4, device=device)
elif model == 'anbof':
self.learner = AttentionNeuralBagOfFeatureLearner(in_channels=self.channels, series_length=self.series_length,
n_class=4, device=device, attention_type='temporal')

self.learner.download(path='.', fold_idx=0)
self.learner.load(path='.')

self.get_logger().info("Heart anomaly detection node initialized.")

def callback(self, msg_data):
"""
Callback that process the input data and publishes to the corresponding topics
:param msg_data: input message
:type msg_data: std_msgs.msg.Float32MultiArray
"""
# Convert Float32MultiArray to OpenDR Timeseries
data = self.bridge.from_rosarray_to_timeseries(msg_data, self.channels, self.series_length)

# Run ecg classification
class_pred = self.learner.infer(data)

# Publish results
ros_class = self.bridge.from_category_to_rosclass(class_pred, self.get_clock().now().to_msg())
self.publisher.publish(ros_class)


def main(args=None):
rclpy.init(args=args)

parser = argparse.ArgumentParser()
parser.add_argument("--input_ecg_topic", type=str, default="/ecg/ecg",
help="listen to input ECG data on this topic")
parser.add_argument("--model", type=str, default="anbof", help="model to be used for prediction: anbof or gru",
choices=["anbof", "gru"])
parser.add_argument("--output_heart_anomaly_topic", type=str, default="/opendr/heart_anomaly",
help="Topic name for heart anomaly detection topic")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cpu, cuda)",
choices=["cuda", "cpu"])
args = parser.parse_args()

try:
if args.device == "cuda" and torch.cuda.is_available():
device = "cuda"
elif args.device == "cuda":
print("GPU not found. Using CPU instead.")
device = "cpu"
else:
print("Using CPU")
device = "cpu"
except:
print("Using CPU")
device = "cpu"

heart_anomaly_detection_node = HeartAnomalyNode(input_ecg_topic=args.input_ecg_topic,
output_heart_anomaly_topic=args.output_heart_anomaly_topic,
model=args.model, device=device)

rclpy.spin(heart_anomaly_detection_node)

heart_anomaly_detection_node.destroy_node()
rclpy.shutdown()


if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion projects/opendr_ws_2/src/opendr_perception/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
'semantic_segmentation_bisenet = opendr_perception.semantic_segmentation_bisenet_node:main',
'face_recognition = opendr_perception.face_recognition_node:main',
'fall_detection = opendr_perception.fall_detection_node:main',
'video_activity_recognition = opendr_perception.video_activity_recognition_node:main'
'video_activity_recognition = opendr_perception.video_activity_recognition_node:main',
'heart_anomaly_detection = opendr_perception.heart_anomaly_detection_node:main',
],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

import numpy as np
from opendr.engine.data import Image
from opendr.engine.data import Image, Timeseries
from opendr.engine.target import Pose, BoundingBox, BoundingBoxList, Category

from cv_bridge import CvBridge
from std_msgs.msg import String
from std_msgs.msg import String, Header
from sensor_msgs.msg import Image as ImageMsg
from vision_msgs.msg import Detection2DArray, Detection2D, BoundingBox2D, ObjectHypothesis, ObjectHypothesisWithPose
from vision_msgs.msg import Detection2DArray, Detection2D, BoundingBox2D, ObjectHypothesis, ObjectHypothesisWithPose, \
Classification2D
from geometry_msgs.msg import Pose2D
from opendr_ros2_messages.msg import OpenDRPose2D, OpenDRPose2DKeypoint

Expand Down Expand Up @@ -280,3 +281,41 @@ def to_ros_category_description(self, category):
result = String()
result.data = category.description
return result

def from_rosarray_to_timeseries(self, ros_array, dim1, dim2):
"""
Converts ROS2 array into OpenDR Timeseries object
:param ros_array: data to be converted
:type ros_array: std_msgs.msg.Float32MultiArray
:param dim1: 1st dimension
:type dim1: int
:param dim2: 2nd dimension
:type dim2: int
:rtype: engine.data.Timeseries
"""
data = np.reshape(ros_array.data, (dim1, dim2))
data = Timeseries(data)
return data

def from_category_to_rosclass(self, prediction, timestamp, source_data=None):
"""
Converts OpenDR Category into Classification2D message with class label, confidence, timestamp and corresponding input
:param prediction: classification prediction
:type prediction: engine.target.Category
:param timestamp: time stamp for header message
:type timestamp: str
:param source_data: corresponding input or None
:return classification
:rtype: vision_msgs.msg.Classification2D
"""
classification = Classification2D()
classification.header = Header()
classification.header.stamp = timestamp

result = ObjectHypothesis()
result.id = str(prediction.data)
result.score = prediction.confidence
classification.results.append(result)
if source_data is not None:
classification.source_img = source_data
return classification