Skip to content

Commit

Permalink
adap to new jdll
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 24, 2024
1 parent 1de81ee commit 453786d
Showing 1 changed file with 20 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.download.DownloadModel;
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.engine.EngineInfo;
Expand All @@ -45,6 +46,7 @@
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
Expand All @@ -56,15 +58,13 @@
import java.net.URL;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileAlreadyExistsException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import org.tensorflow.SavedModelBundle;
Expand Down Expand Up @@ -92,39 +92,6 @@
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
*/
public class Tensorflow2Interface implements DeepLearningEngineInterface {

private static final String[] MODEL_TAGS = { "serve", "inference", "train",
"eval", "gpu", "tpu" };

private static final String[] TF_MODEL_TAGS = {
"tf.saved_model.tag_constants.SERVING",
"tf.saved_model.tag_constants.INFERENCE",
"tf.saved_model.tag_constants.TRAINING",
"tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU",
"tf.saved_model.tag_constants.TPU" };

private static final String[] SIGNATURE_CONSTANTS = { "serving_default",
"inputs", "tensorflow/serving/classify", "classes", "scores", "inputs",
"tensorflow/serving/predict", "outputs", "inputs",
"tensorflow/serving/regress", "outputs", "train", "eval",
"tensorflow/supervised/training", "tensorflow/supervised/eval" };

private static final String[] TF_SIGNATURE_CONSTANTS = {
"tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY",
"tf.saved_model.signature_constants.CLASSIFY_INPUTS",
"tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME",
"tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES",
"tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES",
"tf.saved_model.signature_constants.PREDICT_INPUTS",
"tf.saved_model.signature_constants.PREDICT_METHOD_NAME",
"tf.saved_model.signature_constants.PREDICT_OUTPUTS",
"tf.saved_model.signature_constants.REGRESS_INPUTS",
"tf.saved_model.signature_constants.REGRESS_METHOD_NAME",
"tf.saved_model.signature_constants.REGRESS_OUTPUTS",
"tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY",
"tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY",
"tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME",
"tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" };
/**
* Name without vesion of the JAR created for this library
*/
Expand Down Expand Up @@ -246,7 +213,7 @@ private void checkModelUnzipped() throws LoadModelException, IOException, Except
if (new File(modelFolder, "variables").isDirectory()
&& new File(modelFolder, "saved_model.pb").isFile())
return;
unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME));
unzipTfWeights(ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME));
}

/**
Expand Down Expand Up @@ -278,7 +245,8 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio
* and modifies the output list with the results obtained
*/
@Override
public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
throws RunModelException
{
if (interprocessing) {
Expand Down Expand Up @@ -361,7 +329,8 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
* expected results of the model
* @throws RunModelException if there is any issue running the model
*/
public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException {
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
shmaInputList = new ArrayList<SharedMemoryArray>();
shmaOutputList = new ArrayList<SharedMemoryArray>();
List<String> encIns = modifyForWinCmd(encodeInputs(inputTensors));
Expand Down Expand Up @@ -420,10 +389,10 @@ private static List<String> modifyForWinCmd(List<String> ins){
}


private List<String> encodeInputs(List<Tensor<?>> inputTensors) {
private <T extends RealType<T> & NativeType<T>> List<String> encodeInputs(List<Tensor<T>> inputTensors) {
List<String> encodedInputTensors = new ArrayList<String>();
Gson gson = new Gson();
for (Tensor<?> tt : inputTensors) {
for (Tensor<T> tt : inputTensors) {
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
shmaInputList.add(shma);
HashMap<String, Object> map = new HashMap<String, Object>();
Expand All @@ -438,7 +407,8 @@ private List<String> encodeInputs(List<Tensor<?>> inputTensors) {
}


private List<String> encodeOutputs(List<Tensor<?>> outputTensors) {
private <T extends RealType<T> & NativeType<T>>
List<String> encodeOutputs(List<Tensor<T>> outputTensors) {
Gson gson = new Gson();
List<String> encodedOutputTensors = new ArrayList<String>();
for (Tensor<?> tt : outputTensors) {
Expand Down Expand Up @@ -474,8 +444,9 @@ private List<String> encodeOutputs(List<Tensor<?>> outputTensors) {
* @throws RunModelException If the number of tensors expected is not the same
* as the number of Tensors outputed by the model
*/
public static void fillOutputTensors(
List<org.tensorflow.Tensor> outputTfTensors, List<Tensor<?>> outputTensors)
public static <T extends RealType<T> & NativeType<T>>
void fillOutputTensors(
List<org.tensorflow.Tensor> outputTfTensors, List<Tensor<T>> outputTensors)
throws RunModelException
{
if (outputTfTensors.size() != outputTensors.size())
Expand Down Expand Up @@ -732,19 +703,20 @@ private static String getEnginesDir() {
* @throws RunModelException if there is any error running the model
* @throws URISyntaxException
*/
public static void main(String[] args) throws LoadModelException, IOException, RunModelException, URISyntaxException {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void main(String[] args) throws LoadModelException, IOException, RunModelException, URISyntaxException {

String modelFolder = "/home/carlos/Desktop/Fiji.app/models/model_03bioimageio";
String modelSourc = modelFolder + "/weights-torchscript.pt";
Tensorflow2Interface pi = new Tensorflow2Interface();
try {
pi.loadModel(modelFolder, modelSourc);
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(new long[] {1, 512, 512, 1});
Tensor<?> inp = Tensor.build("aa", "byxc", rai);
Tensor<?> out = Tensor.buildBlankTensor("oo", "bcyx", new long[] {1, 512, 512, 33}, new FloatType());
Tensor<T> inp = (Tensor<T>) Tensor.build("aa", "byxc", rai);
Tensor<R> out = (Tensor<R>) Tensor.buildBlankTensor("oo", "bcyx", new long[] {1, 512, 512, 33}, new FloatType());
//Tensor<?> out = Tensor.buildEmptyTensor("oo", "byxc");
List<Tensor<?>> ins = new ArrayList<Tensor<?>>();
List<Tensor<?>> ous = new ArrayList<Tensor<?>>();
List<Tensor<T>> ins = new ArrayList<Tensor<T>>();
List<Tensor<R>> ous = new ArrayList<Tensor<R>>();
ins.add(inp);
ous.add(out);
pi.run(ins, ous);
Expand Down

0 comments on commit 453786d

Please sign in to comment.