From f527fbc61de4fe8674034532025fab396b571202 Mon Sep 17 00:00:00 2001 From: jameszow Date: Tue, 1 Aug 2023 21:05:52 +0800 Subject: [PATCH 1/8] update version 0.4.0 -> 0.5.0 --- tensorflow-examples/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-examples/pom.xml b/tensorflow-examples/pom.xml index 5feefd3..46bb1df 100644 --- a/tensorflow-examples/pom.xml +++ b/tensorflow-examples/pom.xml @@ -12,7 +12,7 @@ 1.8 1.8 - 0.4.0 + 0.5.0 From 8b242c459b77d513c2114b0d95b5aaebc45a1a63 Mon Sep 17 00:00:00 2001 From: jameszow Date: Tue, 1 Aug 2023 21:06:45 +0800 Subject: [PATCH 2/8] Fix compilation errors caused by modified versions --- .../cnn/fastrcnn/FasterRcnnInception.java | 30 ++++++++++++------- .../linear/LinearRegressionExample.java | 3 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index 8395969..fb9fb89 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -105,11 +105,8 @@ The given SavedModel SignatureDef contains the following output(s): import java.util.HashMap; import java.util.Map; import java.util.TreeMap; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.Tensor; + +import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -227,8 +224,7 @@ public class FasterRcnnInception { "hair brush" }; - public static void main(String[] params) { - + private static void rcnnInception(String [] params) { if (params.length != 2) { throw new IllegalArgumentException("Exactly 2 parameters required !"); } @@ -269,7 +265,9 @@ public static void main(String[] params) { //The given SavedModel SignatureDef input feedDict.put("input_tensor", reshapeTensor); //The given SavedModel MetaGraphDef key - Map outputTensorMap = model.function("serving_default").call(feedDict); + Map outputTensorMap = new HashMap<>(); + // model.function("serving_default").call(feedDict); + //detection_classes, detectionBoxes etc. are model output names try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes"); TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes"); @@ -320,9 +318,9 @@ public static void main(String[] params) { tf.dtypes.cast(tf.reshape( tf.math.mul( tf.image.drawBoundingBoxes(tf.math.div( - tf.dtypes.cast(tf.constant(reshapeTensor), - TFloat32.class), - tf.constant(255.0f) + tf.dtypes.cast(tf.constant(reshapeTensor), + TFloat32.class), + tf.constant(255.0f) ), boxesPlaceHolder, colors), tf.constant(255.0f) @@ -344,4 +342,14 @@ public static void main(String[] params) { } } } + + /** + * 1. test image path + * 2. output image path + * @param params input param + */ + public static void main(String[] params) { + String [] input = {"", ""}; + rcnnInception(input); + } } diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java index 4e8fbd5..c0be7dc 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Random; import org.tensorflow.Graph; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.framework.optimizers.GradientDescent; import org.tensorflow.framework.optimizers.Optimizer; @@ -108,7 +109,7 @@ public static void main(String[] args) { } // Extract linear regression model weight and bias values - List tensorList = session.runner() + Result tensorList = session.runner() .fetch(WEIGHT_VARIABLE_NAME) .fetch(BIAS_VARIABLE_NAME) .run(); From e2f10130d3b2f258322a3b2d6a42cb9a6264717b Mon Sep 17 00:00:00 2001 From: jameszow Date: Tue, 1 Aug 2023 21:15:49 +0800 Subject: [PATCH 3/8] Handling Code Formatting --- .../model/examples/cnn/fastrcnn/FasterRcnnInception.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index fb9fb89..63938ee 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -224,7 +224,7 @@ public class FasterRcnnInception { "hair brush" }; - private static void rcnnInception(String [] params) { + private static void rcnnInception(String[] params) { if (params.length != 2) { throw new IllegalArgumentException("Exactly 2 parameters required !"); } @@ -266,7 +266,7 @@ private static void rcnnInception(String [] params) { feedDict.put("input_tensor", reshapeTensor); //The given SavedModel MetaGraphDef key Map outputTensorMap = new HashMap<>(); - // model.function("serving_default").call(feedDict); + model.function("serving_default").call(feedDict); //detection_classes, detectionBoxes etc. are model output names try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes"); @@ -346,10 +346,11 @@ private static void rcnnInception(String [] params) { /** * 1. test image path * 2. output image path + * * @param params input param */ public static void main(String[] params) { - String [] input = {"", ""}; + String[] input = {"", ""}; rcnnInception(input); } } From 05061cfb6561bdd12687aa7ba6d59dd280fd055c Mon Sep 17 00:00:00 2001 From: jameszow Date: Tue, 1 Aug 2023 21:21:35 +0800 Subject: [PATCH 4/8] Handling Code Formatting --- .../regression/linear/LinearRegressionExample.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java index c0be7dc..4c044f0 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java @@ -16,7 +16,6 @@ */ package org.tensorflow.model.examples.regression.linear; -import java.util.List; import java.util.Random; import org.tensorflow.Graph; import org.tensorflow.Result; @@ -114,8 +113,8 @@ public static void main(String[] args) { .fetch(BIAS_VARIABLE_NAME) .run(); - try (TFloat32 weightValue = (TFloat32)tensorList.get(0); - TFloat32 biasValue = (TFloat32)tensorList.get(1)) { + try (TFloat32 weightValue = (TFloat32) tensorList.get(0); + TFloat32 biasValue = (TFloat32) tensorList.get(1)) { System.out.println("Weight is " + weightValue.getFloat()); System.out.println("Bias is " + biasValue.getFloat()); @@ -127,7 +126,7 @@ public static void main(String[] args) { try (TFloat32 xTensor = TFloat32.scalarOf(x); TFloat32 yTensor = TFloat32.scalarOf(predictedY); - TFloat32 yPredictedTensor = (TFloat32)session.runner() + TFloat32 yPredictedTensor = (TFloat32) session.runner() .feed(xData.asOutput(), xTensor) .feed(yData.asOutput(), yTensor) .fetch(yPredicted) From af3001fd343897f3e8df8ef766d9d400da9eb855 Mon Sep 17 00:00:00 2001 From: James Zow Date: Thu, 3 Aug 2023 00:10:04 +0800 Subject: [PATCH 5/8] fix 2023-08-02 --- .../cnn/fastrcnn/FasterRcnnInception.java | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index 63938ee..9db2eff 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -101,12 +101,13 @@ The given SavedModel SignatureDef contains the following output(s): */ -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.TreeMap; +import java.util.*; -import org.tensorflow.*; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -224,7 +225,7 @@ public class FasterRcnnInception { "hair brush" }; - private static void rcnnInception(String[] params) { + private static void main(String[] params) { if (params.length != 2) { throw new IllegalArgumentException("Exactly 2 parameters required !"); } @@ -266,8 +267,7 @@ private static void rcnnInception(String[] params) { feedDict.put("input_tensor", reshapeTensor); //The given SavedModel MetaGraphDef key Map outputTensorMap = new HashMap<>(); - model.function("serving_default").call(feedDict); - + outputTensorMap.put("module_output", model.function("serving_default").call(feedDict).get(0)); //detection_classes, detectionBoxes etc. are model output names try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes"); TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes"); @@ -342,15 +342,4 @@ private static void rcnnInception(String[] params) { } } } - - /** - * 1. test image path - * 2. output image path - * - * @param params input param - */ - public static void main(String[] params) { - String[] input = {"", ""}; - rcnnInception(input); - } -} +} \ No newline at end of file From 139ee269ba80531e48d385b3483c0e7803a0890a Mon Sep 17 00:00:00 2001 From: James Zow Date: Thu, 3 Aug 2023 00:28:39 +0800 Subject: [PATCH 6/8] fix Import all util packages --- .../model/examples/cnn/fastrcnn/FasterRcnnInception.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index 9db2eff..85e0af7 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -99,10 +99,10 @@ The given SavedModel SignatureDef contains the following output(s): detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. but again the actual tensor is DT_FLOAT according to saved_model_cli. */ - - -import java.util.*; - +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.SavedModelBundle; From 50b5da1d5b7c24987b929ff5549f02abb4fcf104 Mon Sep 17 00:00:00 2001 From: jameszow Date: Fri, 4 Aug 2023 17:10:08 +0800 Subject: [PATCH 7/8] Output Map update -> Result Object --- .../cnn/fastrcnn/FasterRcnnInception.java | 140 +++++++++--------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index 85e0af7..9e1f6ed 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -99,15 +99,18 @@ The given SavedModel SignatureDef contains the following output(s): detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. but again the actual tensor is DT_FLOAT according to saved_model_cli. */ + import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.TreeMap; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.Result; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -225,16 +228,17 @@ public class FasterRcnnInception { "hair brush" }; - private static void main(String[] params) { + public static void main(String[] params) { if (params.length != 2) { throw new IllegalArgumentException("Exactly 2 parameters required !"); } + //my output image String outputImagePath = params[1]; //my test image String imagePath = params[0]; // get path to model folder - String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024"; + String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024_1"; // load saved model SavedModelBundle model = SavedModelBundle.load(modelPath, "serve"); //create a map of the COCO 2017 labels @@ -266,77 +270,71 @@ private static void main(String[] params) { //The given SavedModel SignatureDef input feedDict.put("input_tensor", reshapeTensor); //The given SavedModel MetaGraphDef key - Map outputTensorMap = new HashMap<>(); - outputTensorMap.put("module_output", model.function("serving_default").call(feedDict).get(0)); - //detection_classes, detectionBoxes etc. are model output names - try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes"); - TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes"); - TFloat32 rawDetectionBoxes = (TFloat32) outputTensorMap.get("raw_detection_boxes"); - TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections"); - TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores"); - TFloat32 rawDetectionScores = (TFloat32) outputTensorMap.get("raw_detection_scores"); - TFloat32 detectionAnchorIndices = (TFloat32) outputTensorMap.get("detection_anchor_indices"); - TFloat32 detectionMulticlassScores = (TFloat32) outputTensorMap.get("detection_multiclass_scores")) { - int numDetects = (int) numDetections.getFloat(0); - if (numDetects > 0) { - ArrayList boxArray = new ArrayList<>(); - //TODO tf.image.combinedNonMaxSuppression - for (int n = 0; n < numDetects; n++) { - //put probability and position in outputMap - float detectionScore = detectionScores.getFloat(0, n); - //only include those classes with detection score greater than 0.3f - if (detectionScore > 0.3f) { - boxArray.add(detectionBoxes.get(0, n)); - } + Result result = model.function("serving_default").call(feedDict); + //detection_classes, detectionBoxes, num_detections. are model output names + TFloat32 detectionBoxes = (TFloat32) result.get("detection_boxes").orElseThrow(() -> new RuntimeException("model output exception detection_classes key is null")); + TFloat32 numDetections = (TFloat32) result.get("num_detections").orElseThrow(() -> new RuntimeException("model output exception num_detections key is null")); + TFloat32 detectionScores = (TFloat32) result.get("detection_scores").orElseThrow(() -> new RuntimeException("model output exception detection_scores key is null")); + + int numDetects = (int) numDetections.getFloat(0); + if (numDetects > 0) { + ArrayList boxArray = new ArrayList<>(); + //TODO tf.image.combinedNonMaxSuppression + for (int n = 0; n < numDetects; n++) { + //put probability and position in outputMap + float detectionScore = detectionScores.getFloat(0, n); + //only include those classes with detection score greater than 0.3f + if (detectionScore > 0.3f) { + boxArray.add(detectionBoxes.get(0, n)); } - //2-D. A list of RGBA colors to cycle through for the boxes. - Operand colors = tf.constant(new float[][]{ - {0.9f, 0.3f, 0.3f, 0.0f}, - {0.3f, 0.3f, 0.9f, 0.0f}, - {0.3f, 0.9f, 0.3f, 0.0f} - }); - Shape boxesShape = Shape.of(1, boxArray.size(), 4); - int boxCount = 0; - //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes - try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { - //batch size of 1 - boxes.setFloat(1, 0, 0, 0); - for (FloatNdArray floatNdArray : boxArray) { - boxes.set(floatNdArray, 0, boxCount); - boxCount++; - } - //Placeholders for boxes and path to outputimage - Placeholder boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); - Placeholder outImagePathPlaceholder = tf.placeholder(TString.class); - //Create JPEG from the Tensor with quality of 100% - EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); - //convert the 4D input image to normalised 0.0f - 1.0f - //Draw bounding boxes using boxes tensor and list of colors - //multiply by 255 then reshape and recast to TUint8 3D tensor - WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, - tf.image.encodeJpeg( - tf.dtypes.cast(tf.reshape( - tf.math.mul( - tf.image.drawBoundingBoxes(tf.math.div( - tf.dtypes.cast(tf.constant(reshapeTensor), - TFloat32.class), - tf.constant(255.0f) - ), - boxesPlaceHolder, colors), - tf.constant(255.0f) - ), - tf.array( - imageShape.asArray()[0], - imageShape.asArray()[1], - imageShape.asArray()[2] - ) - ), TUint8.class), - jpgOptions)); - //output the JPEG to file - s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) - .feed(boxesPlaceHolder, boxes) - .addTarget(writeFile).run(); + } + //2-D. A list of RGBA colors to cycle through for the boxes. + Operand colors = tf.constant(new float[][]{ + {0.9f, 0.3f, 0.3f, 0.0f}, + {0.3f, 0.3f, 0.9f, 0.0f}, + {0.3f, 0.9f, 0.3f, 0.0f} + }); + Shape boxesShape = Shape.of(1, boxArray.size(), 4); + int boxCount = 0; + //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes + try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { + //batch size of 1 + boxes.setFloat(1, 0, 0, 0); + for (FloatNdArray floatNdArray : boxArray) { + boxes.set(floatNdArray, 0, boxCount); + boxCount++; } + //Placeholders for boxes and path to outputimage + Placeholder boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); + Placeholder outImagePathPlaceholder = tf.placeholder(TString.class); + //Create JPEG from the Tensor with quality of 100% + EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); + //convert the 4D input image to normalised 0.0f - 1.0f + //Draw bounding boxes using boxes tensor and list of colors + //multiply by 255 then reshape and recast to TUint8 3D tensor + WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, + tf.image.encodeJpeg( + tf.dtypes.cast(tf.reshape( + tf.math.mul( + tf.image.drawBoundingBoxes(tf.math.div( + tf.dtypes.cast(tf.constant(reshapeTensor), + TFloat32.class), + tf.constant(255.0f) + ), + boxesPlaceHolder, colors), + tf.constant(255.0f) + ), + tf.array( + imageShape.asArray()[0], + imageShape.asArray()[1], + imageShape.asArray()[2] + ) + ), TUint8.class), + jpgOptions)); + //output the JPEG to file + s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) + .feed(boxesPlaceHolder, boxes) + .addTarget(writeFile).run(); } } } From 5562984005d09f5f590830e49e112b521b77a017 Mon Sep 17 00:00:00 2001 From: James Zow Date: Fri, 4 Aug 2023 22:49:36 +0800 Subject: [PATCH 8/8] result add try-with-resources close outputs --- .../cnn/fastrcnn/FasterRcnnInception.java | 128 +++++++++--------- 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index 9e1f6ed..0cc7712 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -269,72 +269,74 @@ public static void main(String[] params) { Map feedDict = new HashMap<>(); //The given SavedModel SignatureDef input feedDict.put("input_tensor", reshapeTensor); - //The given SavedModel MetaGraphDef key - Result result = model.function("serving_default").call(feedDict); //detection_classes, detectionBoxes, num_detections. are model output names - TFloat32 detectionBoxes = (TFloat32) result.get("detection_boxes").orElseThrow(() -> new RuntimeException("model output exception detection_classes key is null")); - TFloat32 numDetections = (TFloat32) result.get("num_detections").orElseThrow(() -> new RuntimeException("model output exception num_detections key is null")); - TFloat32 detectionScores = (TFloat32) result.get("detection_scores").orElseThrow(() -> new RuntimeException("model output exception detection_scores key is null")); - - int numDetects = (int) numDetections.getFloat(0); - if (numDetects > 0) { - ArrayList boxArray = new ArrayList<>(); - //TODO tf.image.combinedNonMaxSuppression - for (int n = 0; n < numDetects; n++) { - //put probability and position in outputMap - float detectionScore = detectionScores.getFloat(0, n); - //only include those classes with detection score greater than 0.3f - if (detectionScore > 0.3f) { - boxArray.add(detectionBoxes.get(0, n)); + try (Result result = model.function("serving_default").call(feedDict); + TFloat32 detectionBoxes = (TFloat32) result.get("detection_boxes") + .orElseThrow(() -> new RuntimeException("model output exception detection_boxes key is null")); + TFloat32 numDetections = (TFloat32) result.get("num_detections") + .orElseThrow(() -> new RuntimeException("model output exception num_detections key is null")); + TFloat32 detectionScores = (TFloat32) result.get("detection_scores") + .orElseThrow(() -> new RuntimeException("model output exception detection_scores key is null"))) { + int numDetects = (int) numDetections.getFloat(0); + if (numDetects > 0) { + ArrayList boxArray = new ArrayList<>(); + //TODO tf.image.combinedNonMaxSuppression + for (int n = 0; n < numDetects; n++) { + //put probability and position in outputMap + float detectionScore = detectionScores.getFloat(0, n); + //only include those classes with detection score greater than 0.3f + if (detectionScore > 0.3f) { + boxArray.add(detectionBoxes.get(0, n)); + } } - } - //2-D. A list of RGBA colors to cycle through for the boxes. - Operand colors = tf.constant(new float[][]{ - {0.9f, 0.3f, 0.3f, 0.0f}, - {0.3f, 0.3f, 0.9f, 0.0f}, - {0.3f, 0.9f, 0.3f, 0.0f} - }); - Shape boxesShape = Shape.of(1, boxArray.size(), 4); - int boxCount = 0; - //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes - try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { - //batch size of 1 - boxes.setFloat(1, 0, 0, 0); - for (FloatNdArray floatNdArray : boxArray) { - boxes.set(floatNdArray, 0, boxCount); - boxCount++; + //2-D. A list of RGBA colors to cycle through for the boxes. + Operand colors = tf.constant(new float[][]{ + {0.9f, 0.3f, 0.3f, 0.0f}, + {0.3f, 0.3f, 0.9f, 0.0f}, + {0.3f, 0.9f, 0.3f, 0.0f} + }); + Shape boxesShape = Shape.of(1, boxArray.size(), 4); + int boxCount = 0; + //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes + try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { + //batch size of 1 + boxes.setFloat(1, 0, 0, 0); + for (FloatNdArray floatNdArray : boxArray) { + boxes.set(floatNdArray, 0, boxCount); + boxCount++; + } + //Placeholders for boxes and path to outputimage + Placeholder boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); + Placeholder outImagePathPlaceholder = tf.placeholder(TString.class); + //Create JPEG from the Tensor with quality of 100% + EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); + //convert the 4D input image to normalised 0.0f - 1.0f + //Draw bounding boxes using boxes tensor and list of colors + //multiply by 255 then reshape and recast to TUint8 3D tensor + WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, + tf.image.encodeJpeg( + tf.dtypes.cast(tf.reshape( + tf.math.mul( + tf.image.drawBoundingBoxes(tf.math.div( + tf.dtypes.cast(tf.constant(reshapeTensor), + TFloat32.class), + tf.constant(255.0f) + ), + boxesPlaceHolder, colors), + tf.constant(255.0f) + ), + tf.array( + imageShape.asArray()[0], + imageShape.asArray()[1], + imageShape.asArray()[2] + ) + ), TUint8.class), + jpgOptions)); + //output the JPEG to file + s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) + .feed(boxesPlaceHolder, boxes) + .addTarget(writeFile).run(); } - //Placeholders for boxes and path to outputimage - Placeholder boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); - Placeholder outImagePathPlaceholder = tf.placeholder(TString.class); - //Create JPEG from the Tensor with quality of 100% - EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); - //convert the 4D input image to normalised 0.0f - 1.0f - //Draw bounding boxes using boxes tensor and list of colors - //multiply by 255 then reshape and recast to TUint8 3D tensor - WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, - tf.image.encodeJpeg( - tf.dtypes.cast(tf.reshape( - tf.math.mul( - tf.image.drawBoundingBoxes(tf.math.div( - tf.dtypes.cast(tf.constant(reshapeTensor), - TFloat32.class), - tf.constant(255.0f) - ), - boxesPlaceHolder, colors), - tf.constant(255.0f) - ), - tf.array( - imageShape.asArray()[0], - imageShape.asArray()[1], - imageShape.asArray()[2] - ) - ), TUint8.class), - jpgOptions)); - //output the JPEG to file - s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) - .feed(boxesPlaceHolder, boxes) - .addTarget(writeFile).run(); } } }