
Commit

keep improving the logic for multiprocessing
carlosuc3m committed Jul 22, 2024
1 parent c8ac8a7 commit bdf0c5e
Showing 4 changed files with 515 additions and 6 deletions.
@@ -3,10 +3,12 @@
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.apposed.appose.Service.RequestType;
import io.bioimage.modelrunner.apposed.appose.Service.ResponseType;

@@ -78,12 +80,13 @@ private void executeScript(String script, LinkedHashMap<String, Object> inputs)

this.reportLaunch();
try {
if (script.equals("loadModel"))
if (script.equals("loadModel")) {
ti.loadModel((String) inputs.get("modelFolder"), null);
else if (script.equals("inference"))
ti.run(null, null);
else if (script.equals("close"))
} else if (script.equals("inference")) {
ti.runFromShmas((LinkedHashMap<String, Object>) inputs.get("inputs"), (LinkedHashMap<String, Object>) inputs.get("outputs"));
} else if (script.equals("close")) {
ti.closeModel();
}
} catch(Exception ex) {
this.fail(Types.stackTrace(ex.getCause()));
return;
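For context, a minimal sketch (not part of the commit) of the argument map that the dispatch above expects for an "inference" request. The "inputs", "outputs" and "modelFolder" keys appear in the diff; the per-tensor descriptor strings are placeholders for whatever the interface encodes (per runFromShmas below, a string that Types.decode can turn into a map holding at least the shared-memory name under MEM_NAME_KEY). All literal names and JSON keys below are illustrative.

import java.util.LinkedHashMap;

// Illustrative sketch: the descriptor contents and the shared-memory names ("psm_in0",
// "psm_out0") are invented for this example; only the top-level keys match the diff above.
class InferenceArgsSketch {
	static LinkedHashMap<String, Object> inferenceArgs() {
		LinkedHashMap<String, String> encodedInputs = new LinkedHashMap<>();
		encodedInputs.put("input0", "{\"shm_name\":\"psm_in0\",\"dtype\":\"float32\",\"shape\":[1,512,512,1]}");
		LinkedHashMap<String, String> encodedOutputs = new LinkedHashMap<>();
		encodedOutputs.put("output0", "{\"shm_name\":\"psm_out0\",\"dtype\":\"float32\",\"shape\":[1,512,512,1]}");
		LinkedHashMap<String, Object> args = new LinkedHashMap<>();
		args.put("inputs", encodedInputs);
		args.put("outputs", encodedOutputs);
		return args;
	}
}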
@@ -41,7 +41,10 @@
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.ZipUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

import java.io.File;
import java.io.IOException;
@@ -56,6 +59,8 @@
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import org.tensorflow.SavedModelBundle;
@@ -280,15 +285,15 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
List<String> inputListNames = new ArrayList<String>();
List<TType> inTensors = new ArrayList<TType>();
int c = 0;
- for (Tensor tt : inputTensors) {
+ for (Tensor<?> tt : inputTensors) {
inputListNames.add(tt.getName());
TType inT = TensorBuilder.build(tt);
inTensors.add(inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}
c = 0;
- for (Tensor tt : outputTensors)
+ for (Tensor<?> tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();
@@ -305,6 +310,46 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
}
}

protected void runFromShmas(LinkedHashMap<String, Object> inputs, LinkedHashMap<String, Object> outputs) {
Session session = model.session();
Session.Runner runner = session.runner();

List<TType> inTensors = new ArrayList<TType>();
int c = 0;
for (Entry<String, Object> ee : inputs.entrySet()) {
Map<String, Object> decoded = Types.decode((String) ee.getValue());
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
inTensors.add(inT);
String inputName = getModelInputName(ee.getKey(), c ++);
runner.feed(inputName, inT);
}

c = 0;
for (Entry<String, Object> ee : outputs.entrySet())
runner = runner.fetch(getModelOutputName(ee.getKey(), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();

// Write each inference result into the shared memory block named by the corresponding output descriptor
c = 0;
for (Entry<String, Object> ee : outputs.entrySet()) {
Map<String, Object> decoded = Types.decode((String) ee.getValue());
// Concrete result tensors implement TType, so they can be handed to ShmBuilder directly
io.bioimage.modelrunner.tensorflow.v2.api030.shm.ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
}
// Close the remaining resources
session.close();
for (TType tt : inTensors) {
tt.close();
}
for (org.tensorflow.Tensor tt : resultPatchTensors) {
tt.close();
}
}

/**
* Method only used on MacOS Intel and Windows systems that makes all the arrangements
* to create another process, communicate the model info and tensors to the other process
@@ -322,9 +367,34 @@ public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> out
modifyForWinCmd(encIns);
LinkedHashMap<String, String> encOuts = encodeOutputs(outputTensors);
modifyForWinCmd(encOuts);
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
args.put("inputs", encIns);
args.put("outputs", encOuts);

try {
Task task = runner.task("inference", args);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
for (int i = 0; i < outputTensors.size(); i ++) {
String name = (String) Types.decode(encOuts.get(outputTensors.get(i).getName())).get(MEM_NAME_KEY);
SharedMemoryArray shm = shmaOutputList.stream()
.filter(ss -> ss.getName().equals(name)).findFirst().orElse(null);
if (shm == null) {
shm = SharedMemoryArray.read(name);
shmaOutputList.add(shm);
}
RandomAccessibleInterval<?> rai = shm.getSharedRAI();
outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
}
} catch (Exception e) {
closeShmas();
if (e instanceof RunModelException)
throw (RunModelException) e;
throw new RunModelException(Types.stackTrace(e));
}
closeShmas();
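As a companion to the runInterprocessing diff above, a hedged sketch of the read-back step in isolation: once the worker reports the "inference" task as complete and has filled the output shared-memory blocks, the main process maps each block by name and copies it into the corresponding output tensor. The class and method names below are illustrative; the calls themselves (SharedMemoryArray.read, getSharedRAI, Tensor.createCopyOfRaiInWantedDataType) are the ones used in the diff.

import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

// Illustrative sketch, not part of the commit. "memoryName" is the shared-memory name
// carried by the encoded output descriptor (MEM_NAME_KEY above).
class OutputReadBackSketch {
	static void readOutput(String memoryName, Tensor<?> outputTensor) {
		SharedMemoryArray shm = SharedMemoryArray.read(memoryName);
		RandomAccessibleInterval<?> rai = shm.getSharedRAI();
		// Copy out of the shared buffer before the block is released by closeShmas()
		outputTensor.setData(Tensor.createCopyOfRaiInWantedDataType(
				Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
	}
}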
@@ -0,0 +1,208 @@
/*-
* #%L
* This project complements the DL-model runner acting as the engine that works loading models
* and making inference with Java 0.3.0 and newer API for Tensorflow 2.
* %%
* Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers.
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package io.bioimage.modelrunner.tensorflow.v2.api030.shm;

import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;

import java.nio.ByteBuffer;
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;

import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;

/**
* Utility class that copies the data of TensorFlow 2 {@link Tensor} ({@link TType}) objects
* into {@link SharedMemoryArray} blocks, so that the ImgLib2-backed
* {@link io.bioimage.modelrunner.tensor.Tensor}s in the main process can read them back.
*
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
*/
public final class ShmBuilder
{
/**
* Utility class.
*/
private ShmBuilder()
{
}

/**
* Writes the data of a given {@link TType} tensor into the shared memory block with the
* given name.
*
* @param tensor
* 	The {@link TType} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
* @throws IllegalArgumentException If the {@link TType} tensor type is not supported.
*/
public static void build(TType tensor, String memoryName) throws IllegalArgumentException
{
if (tensor instanceof TUint8)
{
buildFromTensorUByte((TUint8) tensor, memoryName);
}
else if (tensor instanceof TInt32)
{
buildFromTensorInt((TInt32) tensor, memoryName);
}
else if (tensor instanceof TFloat32)
{
buildFromTensorFloat((TFloat32) tensor, memoryName);
}
else if (tensor instanceof TFloat64)
{
buildFromTensorDouble((TFloat64) tensor, memoryName);
}
else if (tensor instanceof TInt64)
{
buildFromTensorLong((TInt64) tensor, memoryName);
}
else
{
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
}
}

/**
* Writes the data of an unsigned byte-typed {@link TUint8} tensor into the shared memory
* block with the given name, as {@link UnsignedByteType} data.
*
* @param tensor
* 	The {@link TUint8} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
*/
private static void buildFromTensorUByte(TUint8 tensor, String memoryName)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.create(arrayShape, new UnsignedByteType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 1;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
}

/**
* Writes the data of an int32-typed {@link TInt32} tensor into the shared memory block
* with the given name, as {@link IntType} data.
*
* @param tensor
* 	The {@link TInt32} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
*/
private static void buildFromTensorInt(TInt32 tensor, String memoryName)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.create(arrayShape, new IntType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
// Copy the raw tensor bytes into the final totalSize bytes, preserving any header at the start of the buffer
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
}

/**
* Writes the data of a float32-typed {@link TFloat32} tensor into the shared memory block
* with the given name, as {@link FloatType} data.
*
* @param tensor
* 	The {@link TFloat32} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
*/
private static void buildFromTensorFloat(TFloat32 tensor, String memoryName)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.create(arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 4;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
// Copy the raw tensor bytes into the final totalSize bytes, preserving any header at the start of the buffer
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
}

/**
* Writes the data of a float64-typed {@link TFloat64} tensor into the shared memory block
* with the given name, as {@link DoubleType} data.
*
* @param tensor
* 	The {@link TFloat64} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
*/
private static void buildFromTensorDouble(TFloat64 tensor, String memoryName)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.create(arrayShape, new DoubleType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
// Copy the raw tensor bytes into the final totalSize bytes, preserving any header at the start of the buffer
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
}

/**
* Writes the data of an int64-typed {@link TInt64} tensor into the shared memory block
* with the given name, as {@link LongType} data.
*
* @param tensor
* 	The {@link TInt64} tensor the data is read from.
* @param memoryName
* 	Name of the shared memory block the data is written to.
*/
private static void buildFromTensorLong(TInt64 tensor, String memoryName)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.create(arrayShape, new LongType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
int totalSize = 8;
for (long i : arrayShape) {totalSize *= i;}
byte[] flatArr = new byte[buff.capacity()];
buff.get(flatArr);
// Copy the raw tensor bytes into the final totalSize bytes, preserving any header at the start of the buffer
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
shma.setBuffer(ByteBuffer.wrap(flatArr));
}
}
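Finally, a hedged usage sketch for the new ShmBuilder class: on the worker side, each TensorFlow result is published under the shared-memory name taken from the corresponding output descriptor, after which the main process can map it (see the read-back sketch above). The helper name publishResult and its wrapping class are illustrative.

import org.tensorflow.types.family.TType;

import io.bioimage.modelrunner.tensorflow.v2.api030.shm.ShmBuilder;

// Illustrative sketch, not part of the commit: publish one inference result so the main
// process can map it. The memory name must be the one sent in the encoded output descriptor.
class ShmPublishSketch {
	static void publishResult(org.tensorflow.Tensor result, String memoryName) {
		// Concrete TF tensors (TFloat32, TInt32, ...) implement both Tensor and TType,
		// so the cast mirrors how runFromShmas hands runner.run() results to ShmBuilder.
		ShmBuilder.build((TType) result, memoryName);
	}
}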