Skip to content

Commit

Permalink
fix: multiply the bbox so it matches the original image size
Browse files Browse the repository at this point in the history
  • Loading branch information
chmjkb committed Dec 16, 2024
1 parent f9d4a01 commit 35661b7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
44 changes: 25 additions & 19 deletions ios/RnExecutorch/models/object_detection/SSDLiteLargeModel.mm
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "SSDLiteLargeModel.hpp"
#include "ImageProcessor.h"
#include "../../utils/ObjectDetectionUtils.hpp"
#include "ImageProcessor.h"
#include <vector>

inline float constexpr iouThreshold = 0.55;
Expand All @@ -16,7 +16,9 @@ - (NSArray *)preprocess:(cv::Mat)input {
return modelInput;
}

- (NSArray *)postprocess:(NSArray *)input {
- (NSArray *)postprocess:(NSArray *)input
widthRatio:(float)widthRatio
heightRatio:(float)heightRatio {
NSArray *bboxes = [input objectAtIndex:0];
NSArray *scores = [input objectAtIndex:1];
NSArray *labels = [input objectAtIndex:2];
Expand All @@ -29,10 +31,10 @@ - (NSArray *)postprocess:(NSArray *)input {
if (score < detectionThreshold) {
continue;
}
float x1 = [bboxes[idx * 4] floatValue];
float y1 = [bboxes[idx * 4 + 1] floatValue];
float x2 = [bboxes[idx * 4 + 2] floatValue];
float y2 = [bboxes[idx * 4 + 3] floatValue];
float x1 = [bboxes[idx * 4] floatValue] * widthRatio;
float y1 = [bboxes[idx * 4 + 1] floatValue] * heightRatio;
float x2 = [bboxes[idx * 4 + 2] floatValue] * widthRatio;
float y2 = [bboxes[idx * 4 + 3] floatValue] * heightRatio;

Detection det = {x1, y1, x2, y2, label, score};
detections.push_back(det);
Expand All @@ -48,23 +50,27 @@ - (NSArray *)postprocess:(NSArray *)input {
}

- (NSArray *)runModel:(cv::Mat)input {
cv::Size size = input.size();
int inputImageWidth = size.width;
int inputImageHeight = size.height;
NSArray *modelInput = [self preprocess:input];
NSError *forwardError = nil;
NSArray *inputShape = @[
@(1),
@(3),
@(inputWidth),
@(inputHeight)
];
NSArray *forwardResult = [self forward:modelInput shape:inputShape inputType:@3 error:&forwardError];
NSArray *inputShape = @[ @(1), @(3), @(inputWidth), @(inputHeight) ];
NSArray *forwardResult = [self forward:modelInput
shape:inputShape
inputType:@3
error:&forwardError];
if (forwardError) {
@
throw [NSException exceptionWithName:@"forward_error"
reason:[NSString stringWithFormat:@"%ld",
static_cast<long>(forwardError.code)]
userInfo:nil];
@throw [NSException
exceptionWithName:@"forward_error"
reason:[NSString
stringWithFormat:@"%ld", static_cast<long>(
forwardError.code)]
userInfo:nil];
}
NSArray *output = [self postprocess:forwardResult];
NSArray *output = [self postprocess:forwardResult
widthRatio:inputImageWidth / 320.f
heightRatio:inputImageHeight / 320.f];
return output;
}

Expand Down
12 changes: 7 additions & 5 deletions ios/RnExecutorch/utils/ObjectDetectionUtils.mm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "ObjectDetectionUtils.hpp"
#include "Constants.h"
#include <map>
#include <vector>
#include "Constants.h"

NSString *floatLabelToNSString(float label) {
int intLabel = static_cast<int>(label);
Expand All @@ -15,10 +15,12 @@

NSDictionary *detectionToNSDictionary(const Detection &detection) {
return @{
@"x1" : @(detection.x1),
@"y1" : @(detection.y1),
@"x2" : @(detection.x2),
@"y2" : @(detection.y2),
@"bbox" : @{
@"x1" : @(detection.x1),
@"y1" : @(detection.y1),
@"x2" : @(detection.x2),
@"y2" : @(detection.y2),
},
@"label" : floatLabelToNSString(detection.label),
@"score" : @(detection.score)
};
Expand Down
4 changes: 2 additions & 2 deletions src/native/NativeObjectDetection.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import type { TurboModule } from 'react-native';
import { TurboModuleRegistry } from 'react-native';
import { ObjectDetectionResult } from '../models/object_detection/types';
import { Detection } from '../models/object_detection/types';

export interface Spec extends TurboModule {
loadModule(modelSource: string): Promise<number>;
forward(input: string): Promise<ObjectDetectionResult>;
forward(input: string): Promise<Detection[]>;
}

export default TurboModuleRegistry.get<Spec>('ObjectDetection');

0 comments on commit 35661b7

Please sign in to comment.