Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add object detection (iOS) #49

Merged
merged 21 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ios/RnExecutorch/ObjectDetection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

// Object detection module class. Adopts the NativeObjectDetectionSpec protocol
// (presumably declared by the RnExecutorchSpec header imported above — confirm).
// Declares no members of its own here.
@interface ObjectDetection : NSObject <NativeObjectDetectionSpec>

@end
56 changes: 56 additions & 0 deletions ios/RnExecutorch/ObjectDetection.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#import "ObjectDetection.h"
#import "models/object_detection/SSDLiteLargeModel.hpp"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import "utils/ImageProcessor.h"

/// Native module that loads an SSDLiteLargeModel and runs it on images on
/// behalf of JavaScript callers.
@implementation ObjectDetection {
  // Model wrapper created by -loadModule:resolve:reject:. It stays nil until
  // that method has been called at least once.
  SSDLiteLargeModel *model;
}

RCT_EXPORT_MODULE()

/// Creates a fresh SSDLiteLargeModel and asks it to load `modelSource`.
///
/// @param modelSource String turned into an NSURL and handed to the loader.
/// @param resolve Called with the loader's error code object when loading
///        reports success.
/// @param reject Called with code "init_module_error" when loading fails; the
///        message is the numeric error code rendered as a string.
- (void)loadModule:(NSString *)modelSource
           resolve:(RCTPromiseResolveBlock)resolve
            reject:(RCTPromiseRejectBlock)reject {
  model = [[SSDLiteLargeModel alloc] init];
  [model loadModel:[NSURL URLWithString:modelSource]
        completion:^(BOOL success, NSNumber *errorCode) {
          if (success) {
            resolve(errorCode);
            return;
          }

          // The domain names this module; it previously carried the
          // copy-pasted "StyleTransferErrorDomain".
          NSError *error = [NSError
              errorWithDomain:@"ObjectDetectionErrorDomain"
                         code:[errorCode intValue]
                     userInfo:@{
                       NSLocalizedDescriptionKey : [NSString
                           stringWithFormat:@"%ld", (long)[errorCode longValue]]
                     }];

          reject(@"init_module_error", error.localizedDescription, error);
        }];
}

/// Reads the image identified by `input`, runs the model on it and resolves
/// with the array returned by -runModel:.
///
/// @param input Image location passed to +[ImageProcessor readImage:].
/// @param resolve Called with the detection array on success.
/// @param reject Called with code "forward_error" when no model has been
///        created, or when an Objective-C or OpenCV exception is raised.
- (void)forward:(NSString *)input
        resolve:(RCTPromiseResolveBlock)resolve
         reject:(RCTPromiseRejectBlock)reject {
  // Messaging a nil model would return nil and silently resolve the promise
  // with no result, so report the misuse instead.
  if (model == nil) {
    reject(@"forward_error", @"Model is not loaded; call loadModule first", nil);
    return;
  }

  @try {
    // @catch (NSException *) does not match C++ exceptions, so OpenCV failures
    // (cv::Exception) need their own handler or they escape this method.
    try {
      cv::Mat image = [ImageProcessor readImage:input];
      NSArray *result = [model runModel:image];
      resolve(result);
    } catch (const cv::Exception &cvError) {
      reject(@"forward_error",
             [NSString stringWithFormat:@"%s", cvError.what()], nil);
    }
  } @catch (NSException *exception) {
    reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
           nil);
  }
}

/// Returns the JSI binding object for this module, built from `params`.
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
    (const facebook::react::ObjCTurboModule::InitParams &)params {
  return std::make_shared<facebook::react::NativeObjectDetectionSpecJSI>(
      params);
}

@end
11 changes: 11 additions & 0 deletions ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#import "../BaseModel.h"
#import <UIKit/UIKit.h>
#include <opencv2/opencv.hpp>

/// Object detection model built on top of BaseModel.
@interface SSDLiteLargeModel : BaseModel

/// Runs detection on `input` and returns the post-processed result array.
- (NSArray *)runModel:(cv::Mat)input;

/// Converts `input` into the array that is fed to the model.
- (NSArray *)preprocess:(cv::Mat)input;

/// Converts raw model output into detection results, scaling box coordinates
/// by `widthRatio` / `heightRatio`.
///
/// This is the selector the implementation defines and -runModel: sends. The
/// previously declared one-argument -postprocess: has no implementation in
/// this class, so it is no longer advertised here.
- (NSArray *)postprocess:(NSArray *)input
              widthRatio:(float)widthRatio
             heightRatio:(float)heightRatio;

@end
65 changes: 65 additions & 0 deletions ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "SSDLiteLargeModel.hpp"
#include "../../utils/ObjectDetectionUtils.hpp"
#include "ImageProcessor.h"
#include <vector>

// Overlap limit handed to nms() in postprocess.
float constexpr iouThreshold = 0.55;
// Detections whose score is below this value are skipped in postprocess.
float constexpr detectionThreshold = 0.7;
// Size the input image is resized to in preprocess; runModel also divides the
// original image size by these values to obtain the box scaling ratios.
int constexpr modelInputWidth = 320;
int constexpr modelInputHeight = 320;

@implementation SSDLiteLargeModel

/// Resizes the image to the model's input size and converts it to an NSArray
/// via +[ImageProcessor matToNSArray:].
- (NSArray *)preprocess:(cv::Mat)input {
  const cv::Size targetSize(modelInputWidth, modelInputHeight);
  cv::resize(input, input, targetSize);
  return [ImageProcessor matToNSArray:input];
}

/// Turns raw model output into an array of detection dictionaries.
///
/// `input` is read as three arrays: flat box coordinates (four values per
/// detection), scores and labels. Entries scoring below detectionThreshold
/// are dropped, box coordinates are multiplied by the given ratios, and the
/// survivors go through nms() before being converted to dictionaries.
- (NSArray *)postprocess:(NSArray *)input
              widthRatio:(float)widthRatio
             heightRatio:(float)heightRatio {
  NSArray *boxCoordinates = input[0];
  NSArray *confidences = input[1];
  NSArray *classIds = input[2];

  std::vector<Detection> candidates;

  for (NSUInteger i = 0; i < confidences.count; ++i) {
    const float confidence = [confidences[i] floatValue];
    const float classId = [classIds[i] floatValue];
    if (confidence < detectionThreshold) {
      continue;
    }

    // Four consecutive values per detection: x1, y1, x2, y2.
    const NSUInteger base = i * 4;
    const float left = [boxCoordinates[base] floatValue] * widthRatio;
    const float top = [boxCoordinates[base + 1] floatValue] * heightRatio;
    const float right = [boxCoordinates[base + 2] floatValue] * widthRatio;
    const float bottom = [boxCoordinates[base + 3] floatValue] * heightRatio;

    candidates.push_back(
        Detection{left, top, right, bottom, classId, confidence});
  }

  const std::vector<Detection> kept = nms(candidates, iouThreshold);

  NSMutableArray *dictionaries = [NSMutableArray array];
  for (const Detection &item : kept) {
    [dictionaries addObject:detectionToNSDictionary(item)];
  }

  return dictionaries;
}

/// Full pipeline: preprocess, forward pass, then postprocess with ratios that
/// scale boxes from model-input size to the size of `input`.
- (NSArray *)runModel:(cv::Mat)input {
  const cv::Size originalSize = input.size();
  const float widthRatio = originalSize.width / (float)modelInputWidth;
  const float heightRatio = originalSize.height / (float)modelInputHeight;

  NSArray *rawOutput = [self forward:[self preprocess:input]];
  return [self postprocess:rawOutput
                widthRatio:widthRatio
               heightRatio:heightRatio];
}

@end
34 changes: 34 additions & 0 deletions ios/RnExecutorch/utils/Constants.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "Constants.h"

// Lookup table from integer class id (1..91) to an upper-case label name.
// NOTE(review): ids 39 and 40 both map to "BASEBALL" — confirm this is
// intended and agrees with whatever consumes these strings.
const std::unordered_map<int, std::string> cocoLabelsMap = {
{1, "PERSON"}, {2, "BICYCLE"}, {3, "CAR"},
{4, "MOTORCYCLE"}, {5, "AIRPLANE"}, {6, "BUS"},
{7, "TRAIN"}, {8, "TRUCK"}, {9, "BOAT"},
{10, "TRAFFIC_LIGHT"}, {11, "FIRE_HYDRANT"}, {12, "STREET_SIGN"},
{13, "STOP_SIGN"}, {14, "PARKING"}, {15, "BENCH"},
{16, "BIRD"}, {17, "CAT"}, {18, "DOG"},
{19, "HORSE"}, {20, "SHEEP"}, {21, "COW"},
{22, "ELEPHANT"}, {23, "BEAR"}, {24, "ZEBRA"},
{25, "GIRAFFE"}, {26, "HAT"}, {27, "BACKPACK"},
{28, "UMBRELLA"}, {29, "SHOE"}, {30, "EYE"},
{31, "HANDBAG"}, {32, "TIE"}, {33, "SUITCASE"},
{34, "FRISBEE"}, {35, "SKIS"}, {36, "SNOWBOARD"},
{37, "SPORTS"}, {38, "KITE"}, {39, "BASEBALL"},
{40, "BASEBALL"}, {41, "SKATEBOARD"}, {42, "SURFBOARD"},
{43, "TENNIS_RACKET"}, {44, "BOTTLE"}, {45, "PLATE"},
{46, "WINE_GLASS"}, {47, "CUP"}, {48, "FORK"},
{49, "KNIFE"}, {50, "SPOON"}, {51, "BOWL"},
{52, "BANANA"}, {53, "APPLE"}, {54, "SANDWICH"},
{55, "ORANGE"}, {56, "BROCCOLI"}, {57, "CARROT"},
{58, "HOT_DOG"}, {59, "PIZZA"}, {60, "DONUT"},
{61, "CAKE"}, {62, "CHAIR"}, {63, "COUCH"},
{64, "POTTED_PLANT"}, {65, "BED"}, {66, "MIRROR"},
{67, "DINING_TABLE"}, {68, "WINDOW"}, {69, "DESK"},
{70, "TOILET"}, {71, "DOOR"}, {72, "TV"},
{73, "LAPTOP"}, {74, "MOUSE"}, {75, "REMOTE"},
{76, "KEYBOARD"}, {77, "CELL_PHONE"}, {78, "MICROWAVE"},
{79, "OVEN"}, {80, "TOASTER"}, {81, "SINK"},
{82, "REFRIGERATOR"}, {83, "BLENDER"}, {84, "BOOK"},
{85, "CLOCK"}, {86, "VASE"}, {87, "SCISSORS"},
{88, "TEDDY_BEAR"}, {89, "HAIR_DRIER"}, {90, "TOOTHBRUSH"},
{91, "HAIR_BRUSH"}};
8 changes: 4 additions & 4 deletions ios/RnExecutorch/utils/ImageProcessor.mm
chmjkb marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
int row = i / mat.cols;
int col = i % mat.cols;
cv::Vec3b pixel = mat.at<cv::Vec3b>(row, col);
floatArray[i] = @(pixel[2] / 255.0f);
floatArray[pixelCount + i] = @(pixel[1] / 255.0f);
floatArray[0 * pixelCount + i] = @(pixel[2] / 255.0f);
floatArray[1 * pixelCount + i] = @(pixel[1] / 255.0f);
floatArray[2 * pixelCount + i] = @(pixel[0] / 255.0f);
}

Expand All @@ -31,8 +31,8 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
int col = i % width;
float r = 0, g = 0, b = 0;

r = [[array objectAtIndex: i] floatValue];
g = [[array objectAtIndex: pixelCount + i] floatValue];
r = [[array objectAtIndex: 0 * pixelCount + i] floatValue];
g = [[array objectAtIndex: 1 * pixelCount + i] floatValue];
b = [[array objectAtIndex: 2 * pixelCount + i] floatValue];

cv::Vec3b color((uchar)(b * 255), (uchar)(g * 255), (uchar)(r * 255));
Expand Down
23 changes: 23 additions & 0 deletions ios/RnExecutorch/utils/ObjectDetectionUtils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef ObjectDetectionUtils_hpp
#define ObjectDetectionUtils_hpp

#import <Foundation/Foundation.h>
#include <stdio.h>
#include <vector>

// One detected object: a box, a class id and a confidence score.
struct Detection {
// Box corners. iou() computes the area as (x2 - x1) * (y2 - y1), so
// (x1, y1) is presumably the minimum corner and (x2, y2) the maximum.
float x1;
float y1;
float x2;
float y2;
// Class id kept as a float; floatLabelToNSString casts it to int for lookup.
float label;
// Confidence value; nms() orders detections of the same label by it.
float score;
};

// Returns the name cocoLabelsMap holds for `label` cast to int, or "unknown"
// when the map has no such entry.
NSString *floatLabelToNSString(float label);
// Packs a detection into a dictionary with keys "bbox" (itself holding
// "x1", "y1", "x2", "y2"), "label" (the label name) and "score".
NSDictionary *detectionToNSDictionary(const Detection &detection);
// Intersection area of the two boxes divided by the area of their union.
float iou(const Detection &a, const Detection &b);
// Non-maximum suppression applied separately to each label: detections are
// visited in descending score order and those overlapping an already kept
// detection by more than `iouThreshold` are dropped.
std::vector<Detection> nms(std::vector<Detection> detections,
float iouThreshold);

#endif /* ObjectDetectionUtils_hpp */
84 changes: 84 additions & 0 deletions ios/RnExecutorch/utils/ObjectDetectionUtils.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include "ObjectDetectionUtils.hpp"
#include "Constants.h"
#include <algorithm>
#include <map>
#include <vector>

// Looks up the label name for a class id. The id is cast to int before the
// lookup; ids missing from cocoLabelsMap yield "unknown".
NSString *floatLabelToNSString(float label) {
  const auto entry = cocoLabelsMap.find(static_cast<int>(label));
  if (entry == cocoLabelsMap.end()) {
    return @"unknown";
  }
  return [NSString stringWithUTF8String:entry->second.c_str()];
}

// Converts a Detection into a dictionary with the box corners under "bbox",
// the label name under "label" and the confidence under "score".
NSDictionary *detectionToNSDictionary(const Detection &detection) {
  NSDictionary *boundingBox = @{
    @"x1" : @(detection.x1),
    @"y1" : @(detection.y1),
    @"x2" : @(detection.x2),
    @"y2" : @(detection.y2),
  };
  NSString *labelName = floatLabelToNSString(detection.label);
  NSNumber *confidence = @(detection.score);

  return @{
    @"bbox" : boundingBox,
    @"label" : labelName,
    @"score" : confidence
  };
}

// Intersection-over-union of two axis-aligned boxes.
//
// Returns the overlap area divided by the union area. When the union is not
// positive (zero-area or inverted boxes) the result is 0 rather than the
// NaN / infinity / negative value the plain division would produce.
float iou(const Detection &a, const Detection &b) {
  const float overlapWidth =
      std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
  const float overlapHeight =
      std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
  const float intersectionArea = overlapWidth * overlapHeight;

  const float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
  const float unionArea = areaA + areaB - intersectionArea;

  // Written as a negated comparison so that a NaN union is rejected too.
  if (!(unionArea > 0.0f)) {
    return 0.0f;
  }
  return intersectionArea / unionArea;
}

std::vector<Detection> nms(std::vector<Detection> detections,
float iouThreshold) {
if (detections.empty()) {
return {};
}

// Sort by label, then by score
std::sort(detections.begin(), detections.end(),
[](const Detection &a, const Detection &b) {
if (a.label == b.label) {
return a.score > b.score;
}
return a.label < b.label;
});

std::vector<Detection> result;
// Apply NMS for each label
for (size_t i = 0; i < detections.size();) {
float currentLabel = detections[i].label;

std::vector<Detection> labelDetections;
while (i < detections.size() && detections[i].label == currentLabel) {
labelDetections.push_back(detections[i]);
++i;
}

std::vector<Detection> filteredLabelDetections;
while (!labelDetections.empty()) {
Detection current = labelDetections.front();
filteredLabelDetections.push_back(current);
labelDetections.erase(
std::remove_if(labelDetections.begin(), labelDetections.end(),
[&](const Detection &other) {
return iou(current, other) > iouThreshold;
}),
labelDetections.end());
}
result.insert(result.end(), filteredLabelDetections.begin(),
filteredLabelDetections.end());
}
return result;
}
2 changes: 2 additions & 0 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ export * from './ETModule';
export * from './LLM';
export * from './StyleTransfer';
export * from './constants/modelUrls';
export * from './models/object_detection/ObjectDetection';
export * from './models/object_detection/types';
63 changes: 63 additions & 0 deletions src/models/object_detection/ObjectDetection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { useEffect, useState } from 'react';
import { Image } from 'react-native';
import { ETError, getError } from '../../Error';
import { ObjectDetection } from '../../native/RnExecutorchModules';
import { Detection } from './types';

/** Options accepted by `useObjectDetection`. */
interface Props {
// A string is handed to the native loader as-is; a number is treated as an
// asset id and resolved to a URI with `Image.resolveAssetSource` first.
modelSource: string | number;
}

/** Value returned by `useObjectDetection`. */
interface ObjectDetectionModule {
// Message produced by `getError` when loading the model failed, else null.
error: string | null;
// True once the native `loadModule` call has completed successfully.
isModelReady: boolean;
// True while a `forward` call is in progress.
isModelGenerating: boolean;
// Runs detection on `input`; throws if the model is not ready, a call is
// already in progress, or the native call fails.
forward: (input: string) => Promise<Detection[]>;
}

/**
 * React hook that loads an object-detection model and exposes a `forward`
 * function to run it.
 *
 * The model is (re)loaded whenever `modelSource` changes. A numeric
 * `modelSource` is resolved to a URI with `Image.resolveAssetSource`; a string
 * is passed to the native module unchanged.
 *
 * @param modelSource Model location (string) or asset id (number).
 * @returns Load error (if any), readiness / busy flags and the `forward`
 *   function.
 */
export const useObjectDetection = ({
  modelSource,
}: Props): ObjectDetectionModule => {
  const [error, setError] = useState<null | string>(null);
  const [isModelReady, setIsModelReady] = useState(false);
  const [isModelGenerating, setIsModelGenerating] = useState(false);

  useEffect(() => {
    // Flipped by the cleanup below so that a load which was superseded by a
    // newer `modelSource` (or outlived the component) cannot update state.
    let cancelled = false;

    const loadModel = async () => {
      let path = modelSource;
      if (typeof modelSource === 'number') {
        path = Image.resolveAssetSource(modelSource).uri;
      }

      try {
        // Clear the failure of a previous load so it does not linger after
        // this one succeeds.
        setError(null);
        setIsModelReady(false);
        await ObjectDetection.loadModule(path);
        if (!cancelled) {
          setIsModelReady(true);
        }
      } catch (e) {
        if (!cancelled) {
          setError(getError(e));
        }
      }
    };

    loadModel();

    return () => {
      cancelled = true;
    };
  }, [modelSource]);

  /**
   * Runs detection on `input`.
   *
   * @throws Error when the model is not loaded, when another call is already
   *   running, or when the native call rejects.
   */
  const forward = async (input: string) => {
    if (!isModelReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (isModelGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }
    try {
      setIsModelGenerating(true);
      const output = await ObjectDetection.forward(input);
      return output;
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      setIsModelGenerating(false);
    }
  };

  return { error, isModelReady, isModelGenerating, forward };
};
Loading
Loading