Commit 4f225d3

abdelaziz-mahdy committed Aug 21, 2024
2 parents 75baf98 + 18976c7
Showing 3 changed files with 62 additions and 28 deletions.
4 changes: 2 additions & 2 deletions example/lib/ui/camera_view.dart
@@ -187,7 +187,7 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
     Stopwatch stopwatch = Stopwatch()..start();
 
     String imageClassification = await _imageModel!
-        .getCameraImagePrediction(cameraImage, _camFrameRotation);
+        .getCameraImagePrediction(cameraImage, rotation: _camFrameRotation);
     // Stop the stopwatch
     stopwatch.stop();
     // print("imageClassification $imageClassification");
@@ -221,7 +221,7 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
     List<ResultObjectDetection> objDetect =
         await _objectModel!.getCameraImagePrediction(
       cameraImage,
-      _camFrameRotation,
+      rotation: _camFrameRotation,
       minimumScore: 0.3,
       iOUThreshold: 0.3,
     );
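The call-site change above is the entire migration for app code: `rotation` moves from a positional argument to a named, optional one. A minimal sketch of an updated camera-stream handler, assuming `_imageModel`, `_objectModel`, and `_camFrameRotation` are initialized as in the example app:

    // Sketch only: _imageModel, _objectModel, and _camFrameRotation are
    // assumed to be set up as in example/lib/ui/camera_view.dart.
    Future<void> onNewFrame(CameraImage cameraImage) async {
      // Classification: rotation is now a named, optional parameter.
      String label = await _imageModel!
          .getCameraImagePrediction(cameraImage, rotation: _camFrameRotation);

      // Object detection: same change, next to the existing named options.
      List<ResultObjectDetection> objDetect =
          await _objectModel!.getCameraImagePrediction(
        cameraImage,
        rotation: _camFrameRotation,
        minimumScore: 0.3,
        iOUThreshold: 0.3,
      );
      print('$label, ${objDetect.length} detections');
    }
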
32 changes: 23 additions & 9 deletions lib/image_utils_isolate.dart
@@ -41,6 +41,7 @@ class ImageUtilsIsolate {
     List<Uint8List>? planes = values[3];
     int width = values[4];
     int height = values[5];
+    int? rotation = values[6];
     Image? image;
     if (imageFormatGroup == ImageFormatGroup.yuv420) {
       image = convertYUV420ToImage(
@@ -50,31 +51,43 @@ class ImageUtilsIsolate {
     } else {
       image = null;
     }
+
+    rotation ??= Platform.isAndroid ? 90 : 0;
     if (image != null) {
       if (Platform.isIOS) {
         // ios, default camera image is portrait view
         // rotate 270 to the view that top is on the left, bottom is on the right
         // image ^4.0.17 error here
-        image = copyRotate(image, angle: 270);
+        image = copyRotate(image, angle: rotation);
       }
       if (Platform.isAndroid) {
-        image = copyRotate(image, angle: 90);
+        image = copyRotate(image, angle: rotation);
       }
       return TransferableTypedData.fromList([encodeJpg(image)]);
     }
     return null;
   }
 
-  static List<dynamic> _getParamsBasedOnType(CameraImage cameraImage) {
+  Uint8List _rotateImageBytes(Uint8List imageBytes, int rotation) {
+    Image? image = decodeImage(imageBytes);
+    if (image == null) {
+      throw Exception("Unable to decode image bytes");
+    }
+
+    Image rotatedImage = copyRotate(image, angle: rotation);
+    return Uint8List.fromList(encodeJpg(rotatedImage));
+  }
+
+  static List<dynamic> _getParamsBasedOnType(CameraImage cameraImage,
+      {int? rotation}) {
     if (cameraImage.format.group == ImageFormatGroup.yuv420) {
       return [
         cameraImage.format.group,
         cameraImage.planes[1].bytesPerRow,
         cameraImage.planes[1].bytesPerPixel ?? 0,
         cameraImage.planes.map((e) => e.bytes).toList(),
         cameraImage.width,
-        cameraImage.height
+        cameraImage.height,
+        rotation
       ];
     } else if (cameraImage.format.group == ImageFormatGroup.bgra8888) {
       return [
@@ -83,20 +96,21 @@ class ImageUtilsIsolate {
         null,
         cameraImage.planes.map((e) => e.bytes).toList(),
         cameraImage.width,
-        cameraImage.height
+        cameraImage.height,
+        rotation
       ];
     }
     // You can add more formats as needed
     return [];
   }
 
   /// Converts a [CameraImage] in YUV420 format to [Image] in RGB format
-  static Future<Uint8List?> convertCameraImageToBytes(
-      CameraImage cameraImage) async {
+  static Future<Uint8List?> convertCameraImageToBytes(CameraImage cameraImage,
+      {int? rotation}) async {
     await ImageUtilsIsolate.init();
 
     return (await ImageUtilsIsolate.computer.compute(_convertCameraImageToBytes,
-            param: _getParamsBasedOnType(cameraImage))
+            param: _getParamsBasedOnType(cameraImage, rotation: rotation))
         as TransferableTypedData?)
         ?.materialize()
         .asUint8List();
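The worker isolate receives its arguments as a positional list, so threading `rotation` through means appending it in `_getParamsBasedOnType` and reading it back as `values[6]` in `_convertCameraImageToBytes`. A minimal sketch of the packing side for the YUV420 branch, mirroring the index layout above (the surrounding variable names are illustrative):

    // Positional protocol handed to the isolate; indices must match the
    // reads at the top of _convertCameraImageToBytes.
    final List<dynamic> params = [
      cameraImage.format.group,                        // 0: format group
      cameraImage.planes[1].bytesPerRow,               // 1: UV bytes per row
      cameraImage.planes[1].bytesPerPixel ?? 0,        // 2: UV bytes per pixel
      cameraImage.planes.map((e) => e.bytes).toList(), // 3: raw plane bytes
      cameraImage.width,                               // 4: frame width
      cameraImage.height,                              // 5: frame height
      rotation, // 6: int?; when null the isolate falls back to
                //    90 on Android and 0 on iOS.
    ];
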
54 changes: 37 additions & 17 deletions lib/pytorch_lite.dart
@@ -11,6 +11,7 @@ import 'package:pytorch_lite/enums/model_type.dart';
 import 'package:pytorch_lite/image_utils_isolate.dart';
 import 'package:pytorch_lite/pigeon.dart';
 import 'package:collection/collection.dart';
+import 'package:image/image.dart' as img;
 
 export 'enums/dtype.dart';
 export 'package:pytorch_lite/pigeon.dart';
@@ -124,6 +125,16 @@ Future<List<String>> _getLabelsTxt(String labelPath) async {
   return labelsData.split("\n");
 }
 
+Uint8List _rotateImageBytes(Uint8List imageBytes, int rotation) {
+  img.Image? image = img.decodeImage(imageBytes);
+  if (image == null) {
+    throw Exception("Unable to decode image bytes");
+  }
+
+  img.Image rotatedImage = img.copyRotate(image, angle: rotation);
+  return Uint8List.fromList(img.encodeJpg(rotatedImage));
+}
+
 /*
 class CustomModel {
   final int _index;
@@ -332,13 +343,13 @@ class ClassificationModel {
 
   /// Retrieves a list of predictions for a camera image.
   ///
-  /// Takes a [cameraImage] and [rotation] as input. Optional parameters include [mean], [std],
+  /// Takes a [cameraImage] as input. Optional parameters include [rotation], [mean], [std],
   /// [cameraPreProcessingMethod], and [preProcessingMethod].
   /// Returns a [Future] that resolves to a [List] of [double] values representing the predictions.
   /// Throws an [Exception] if unable to process the image bytes.
-  Future<List<double>> getCameraImagePredictionList(
-      CameraImage cameraImage, int rotation,
-      {List<double> mean = torchVisionNormMeanRGB,
+  Future<List<double>> getCameraImagePredictionList(CameraImage cameraImage,
+      {int? rotation,
+      List<double> mean = torchVisionNormMeanRGB,
       List<double> std = torchVisionNormSTDRGB,
       CameraPreProcessingMethod cameraPreProcessingMethod =
           CameraPreProcessingMethod.imageLib,
@@ -351,6 +362,7 @@ class ClassificationModel {
     if (bytes == null) {
       throw Exception("Unable to process image bytes");
     }
+
     // Retrieve the image predictions for the preprocessed image bytes
     return await getImagePredictionList(bytes,
         mean: mean, std: std, preProcessingMethod: preProcessingMethod);
@@ -366,19 +378,21 @@ class ClassificationModel {
 
   /// Retrieves the top prediction label for a camera image.
   ///
-  /// Takes a [cameraImage] and [rotation] as input. Optional parameters include [mean], [std],
+  /// Takes a [cameraImage] as input. Optional parameters include [rotation], [mean], [std],
   /// [cameraPreProcessingMethod], and [preProcessingMethod].
   /// Returns a [Future] that resolves to a [String] representing the top prediction label.
-  Future<String> getCameraImagePrediction(CameraImage cameraImage, int rotation,
-      {List<double> mean = torchVisionNormMeanRGB,
+  Future<String> getCameraImagePrediction(CameraImage cameraImage,
+      {int? rotation,
+      List<double> mean = torchVisionNormMeanRGB,
       List<double> std = torchVisionNormSTDRGB,
       CameraPreProcessingMethod cameraPreProcessingMethod =
           CameraPreProcessingMethod.imageLib,
       PreProcessingMethod preProcessingMethod =
           PreProcessingMethod.imageLib}) async {
     // Retrieve the prediction list for the camera image
     final List<double> prediction = await getCameraImagePredictionList(
-        cameraImage, rotation,
+        cameraImage,
+        rotation: rotation,
         mean: mean,
         std: std,
         cameraPreProcessingMethod: cameraPreProcessingMethod,
@@ -392,20 +406,22 @@ class ClassificationModel {
 
   /// Retrieves the probabilities of predictions for a camera image.
   ///
-  /// Takes a [cameraImage] and [rotation] as input. Optional parameters include [mean], [std],
+  /// Takes a [cameraImage] as input. Optional parameters include [rotation], [mean], [std],
   /// [cameraPreProcessingMethod], and [preProcessingMethod].
   /// Returns a [Future] that resolves to a [List] of [double] values representing the prediction probabilities.
   Future<List<double>> getCameraImagePredictionProbabilities(
-      CameraImage cameraImage, int rotation,
-      {List<double> mean = torchVisionNormMeanRGB,
+      CameraImage cameraImage,
+      {int? rotation,
+      List<double> mean = torchVisionNormMeanRGB,
       List<double> std = torchVisionNormSTDRGB,
       CameraPreProcessingMethod cameraPreProcessingMethod =
           CameraPreProcessingMethod.imageLib,
       PreProcessingMethod preProcessingMethod =
           PreProcessingMethod.imageLib}) async {
     // Retrieve the prediction list for the camera image
     final List<double> prediction = await getCameraImagePredictionList(
-        cameraImage, rotation,
+        cameraImage,
+        rotation: rotation,
         mean: mean,
         std: std,
         cameraPreProcessingMethod: cameraPreProcessingMethod,
@@ -583,8 +599,9 @@ class ModelObjectDetection {
   /// The optional parameters [minimumScore], [iOUThreshold], [boxesLimit], [cameraPreProcessingMethod], and [preProcessingMethod]
   /// allow customization of the prediction process.
   Future<List<ResultObjectDetection>> getCameraImagePredictionList(
-      CameraImage cameraImage, int rotation,
-      {double minimumScore = 0.5,
+      CameraImage cameraImage,
+      {int? rotation,
+      double minimumScore = 0.5,
       double iOUThreshold = 0.5,
       int boxesLimit = 10,
       CameraPreProcessingMethod cameraPreProcessingMethod =
@@ -598,6 +615,7 @@ class ModelObjectDetection {
     if (bytes == null) {
       throw Exception("Unable to process image bytes");
     }
+
     // Get the image prediction list using the converted bytes
     return await getImagePredictionList(bytes,
         minimumScore: minimumScore,
@@ -620,16 +638,18 @@ class ModelObjectDetection {
   /// The optional parameters [minimumScore], [iOUThreshold], [boxesLimit], [cameraPreProcessingMethod], and [preProcessingMethod]
   /// allow customization of the prediction process.
   Future<List<ResultObjectDetection>> getCameraImagePrediction(
-      CameraImage cameraImage, int rotation,
-      {double minimumScore = 0.5,
+      CameraImage cameraImage,
+      {int? rotation,
+      double minimumScore = 0.5,
       double iOUThreshold = 0.5,
       int boxesLimit = 10,
       CameraPreProcessingMethod cameraPreProcessingMethod =
           CameraPreProcessingMethod.imageLib,
       PreProcessingMethod preProcessingMethod =
           PreProcessingMethod.imageLib}) async {
     final List<ResultObjectDetection> prediction =
-        await getCameraImagePredictionList(cameraImage, rotation,
+        await getCameraImagePredictionList(cameraImage,
+            rotation: rotation,
             minimumScore: minimumScore,
             iOUThreshold: iOUThreshold,
             boxesLimit: boxesLimit,
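Taken together, `rotation` is now optional across the public camera APIs; when omitted, the isolate applies the platform default (90 on Android, 0 on iOS). A short usage sketch, assuming a loaded `ClassificationModel` named `model` and a `CameraImage` from a live stream:

    // Omit rotation to accept the platform default chosen in the isolate.
    String top = await model.getCameraImagePrediction(cameraImage);

    // Or pass an explicit angle, e.g. one derived from the camera
    // controller's sensor orientation (how it is computed is app-specific).
    List<double> probs = await model.getCameraImagePredictionProbabilities(
      cameraImage,
      rotation: 270,
    );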
