Skip to content

Commit

Permalink
keep correcting bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jul 22, 2024
1 parent 1f83610 commit 75fc740
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.apposed.appose.Service.RequestType;
import io.bioimage.modelrunner.apposed.appose.Service.ResponseType;

Expand Down Expand Up @@ -48,7 +48,7 @@ public static void main(String[] args) {

if (requestType.equals(RequestType.EXECUTE.toString())) {
String script = (String) request.get("script");
LinkedHashMap<String, Object> inputs = (LinkedHashMap<String, Object>) request.get("inputs");
Map<String, Object> inputs = (Map<String, Object>) request.get("inputs");
JavaWorker task = new JavaWorker(uuid, ti);
tasks.put(uuid, task);
task.start(script, inputs);
Expand All @@ -72,29 +72,30 @@ private JavaWorker(String uuid, Tensorflow2Interface ti) {
this.ti = ti;
}

private void executeScript(String script, LinkedHashMap<String, Object> inputs) {
LinkedHashMap<String, Object> binding = new LinkedHashMap<String, Object>();
private void executeScript(String script, Map<String, Object> inputs) {
Map<String, Object> binding = new LinkedHashMap<String, Object>();
binding.put("task", this);
if (inputs != null)
binding.putAll(binding);

this.reportLaunch();
try {
if (script.equals("loadModel")) {
System.out.println((String) inputs.get("modelFolder"));
ti.loadModel((String) inputs.get("modelFolder"), null);
} else if (script.equals("inference")) {
ti.runFromShmas((LinkedHashMap<String, Object>) inputs.get("inputs"), (LinkedHashMap<String, Object>) inputs.get("outputs"));
ti.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
} else if (script.equals("close")) {
ti.closeModel();
}
} catch(Exception ex) {
this.fail(Types.stackTrace(ex.getCause()));
this.fail(Types.stackTrace(ex));
return;
}
this.reportCompletion();
}

private void start(String script, LinkedHashMap<String, Object> inputs) {
private void start(String script, Map<String, Object> inputs) {
new Thread(() -> executeScript(script, inputs), "Appose-" + this.uuid).start();
}

Expand All @@ -120,12 +121,12 @@ private void update(String message, Integer current, Integer maximum) {
this.respond(ResponseType.UPDATE, args);
}

private void respond(ResponseType responseType, LinkedHashMap<String, Object> args) {
LinkedHashMap<String, Object> response = new LinkedHashMap<String, Object>();
private void respond(ResponseType responseType, Map<String, Object> args) {
Map<String, Object> response = new HashMap<String, Object>();
response.put("task", uuid);
response.put("responseType", responseType);
if (args != null)
response.putAll(response);
response.putAll(args);
try {
System.out.println(Types.encode(response));
System.out.flush();
Expand All @@ -139,9 +140,10 @@ private void cancel() {
}

private void fail(String error) {
LinkedHashMap<String, Object> args = null;
System.out.println(error);
Map<String, Object> args = null;
if (error != null) {
args = new LinkedHashMap<String, Object>();
args = new HashMap<String, Object>();
args.put("error", error);
}
respond(ResponseType.FAILURE, args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.ZipUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

Expand Down Expand Up @@ -165,6 +167,7 @@ public Tensorflow2Interface(boolean doInterprocessing) throws IOException, URISy
interprocessing = doInterprocessing;
if (this.interprocessing) {
runner = getRunner();
runner.debug((text) -> System.err.println(text));
}
}

Expand All @@ -178,7 +181,7 @@ private Service getRunner() throws IOException, URISyntaxException {
String[] argArr = new String[args.size()];
args.toArray(argArr);

return new Service(new File(argArr[0]), argArr);
return new Service(new File("."), argArr);
}

/**
Expand Down Expand Up @@ -220,7 +223,7 @@ public void loadModel(String modelFolder, String modelSource)

private void launchModelLoadOnProcess() throws IOException, InterruptedException {
HashMap<String, Object> args = new HashMap<String, Object>();
args.put("mdoelFolder", modelFolder);
args.put("modelFolder", modelFolder);
Task task = runner.task("loadModel", args);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
Expand Down Expand Up @@ -311,31 +314,31 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
}
}

protected void runFromShmas(LinkedHashMap<String, Object> inputs, LinkedHashMap<String, Object> outputs) throws IOException {
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
Session session = model.session();
Session.Runner runner = session.runner();

List<TType> inTensors = new ArrayList<TType>();
int c = 0;
for (Entry<String, Object> ee : inputs.entrySet()) {
Map<String, Object> decoded = Types.decode((String) ee.getValue());
for (String ee : inputs) {
Map<String, Object> decoded = Types.decode(ee);
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
if (PlatformDetection.isWindows()) shma.close();
inTensors.add(inT);
String inputName = getModelInputName(ee.getKey(), c ++);
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
runner.feed(inputName, inT);
}

c = 0;
for (Entry<String, Object> ee : outputs.entrySet())
runner = runner.fetch(getModelOutputName(ee.getKey(), c ++));
for (String ee : outputs)
runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();

// Fill the agnostic output tensors list with data from the inference result
for (Entry<String, Object> ee : outputs.entrySet()) {
Map<String, Object> decoded = Types.decode((String) ee.getValue());
for (String ee : outputs) {
Map<String, Object> decoded = Types.decode(ee);
ShmBuilder.build((TType) resultPatchTensors.get(c), (String) decoded.get(MEM_NAME_KEY));
}
// Close the remaining resources
Expand All @@ -361,10 +364,8 @@ protected void runFromShmas(LinkedHashMap<String, Object> inputs, LinkedHashMap<
public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException {
shmaInputList = new ArrayList<SharedMemoryArray>();
shmaOutputList = new ArrayList<SharedMemoryArray>();
LinkedHashMap<String, String> encIns = encodeInputs(inputTensors);
modifyForWinCmd(encIns);
LinkedHashMap<String, String> encOuts = encodeOutputs(outputTensors);
modifyForWinCmd(encOuts);
List<String> encIns = modifyForWinCmd(encodeInputs(inputTensors));
List<String> encOuts = modifyForWinCmd(encodeOutputs(outputTensors));
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
args.put("inputs", encIns);
args.put("outputs", encOuts);
Expand All @@ -379,7 +380,7 @@ else if (task.status == TaskStatus.FAILED)
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
for (int i = 0; i < outputTensors.size(); i ++) {
String name = (String) Types.decode(encOuts.get(outputTensors.get(i).getName())).get(MEM_NAME_KEY);
String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY);
SharedMemoryArray shm = shmaOutputList.stream()
.filter(ss -> ss.getName().equals(name)).findFirst().orElse(null);
if (shm == null) {
Expand Down Expand Up @@ -409,19 +410,18 @@ private void closeShmas() {
shmaOutputList = null;
}

private static void modifyForWinCmd(LinkedHashMap<String, String> ins){
private static List<String> modifyForWinCmd(List<String> ins){
if (!PlatformDetection.isWindows())
return;
ins.entrySet().forEach(ee ->{
String val = ee.getValue();
String nVal = "\"" + val.replace("\"", "\\\"") + "\"";
ee.setValue(nVal);
});
return ins;
List<String> newIns = new ArrayList<String>();
for (String ii : ins)
newIns.add("\"" + ii.replace("\"", "\\\"") + "\"");
return newIns;
}


private LinkedHashMap<String, String> encodeInputs(List<Tensor<?>> inputTensors) {
LinkedHashMap<String, String> encodedInputTensors = new LinkedHashMap<String, String>();
private List<String> encodeInputs(List<Tensor<?>> inputTensors) {
List<String> encodedInputTensors = new ArrayList<String>();
Gson gson = new Gson();
for (Tensor<?> tt : inputTensors) {
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
Expand All @@ -432,15 +432,15 @@ private LinkedHashMap<String, String> encodeInputs(List<Tensor<?>> inputTensors)
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData()));
map.put(IS_INPUT_KEY, true);
map.put(MEM_NAME_KEY, shma.getName());
encodedInputTensors.put(tt.getName(), gson.toJson(map));
encodedInputTensors.add(gson.toJson(map));
}
return encodedInputTensors;
}


private LinkedHashMap<String, String> encodeOutputs(List<Tensor<?>> outputTensors) {
private List<String> encodeOutputs(List<Tensor<?>> outputTensors) {
Gson gson = new Gson();
LinkedHashMap<String, String>encodedOutputTensors = new LinkedHashMap<String, String>();
List<String> encodedOutputTensors = new ArrayList<String>();
for (Tensor<?> tt : outputTensors) {
HashMap<String, Object> map = new HashMap<String, Object>();
map.put(NAME_KEY, tt.getName());
Expand All @@ -460,7 +460,7 @@ private LinkedHashMap<String, String> encodeOutputs(List<Tensor<?>> outputTensor
map.put(MEM_NAME_KEY, memName);
shmaNamesList.add(memName);
}
encodedOutputTensors.put(tt.getName(), gson.toJson(map));
encodedOutputTensors.add(gson.toJson(map));
}
return encodedOutputTensors;
}
Expand Down Expand Up @@ -580,33 +580,6 @@ public static String getModelOutputName(String outputName, int i) {
}
}


/**
* Methods to run interprocessing and bypass the errors that occur in MacOS intel
* with the compatibility between TF2 and TF1/Pytorch
* This method checks that the arguments are correct, retrieves the input and output
* tensors, loads the model, makes inference with it and finally sends the tensors
* to the original process
*
* @param args
* arguments of the program:
* - Path to the model folder
* - Path to a temporary dir
* - Name of the input 0
* - Name of the input 1
* - ...
* - Name of the output n
* - Name of the output 0
* - Name of the output 1
* - ...
* - Name of the output n
* @throws LoadModelException if there is any error loading the model
* @throws IOException if there is any error reading or writing any file or with the paths
* @throws RunModelException if there is any error running the model
*/
public static void main(String[] args) throws LoadModelException, IOException, RunModelException {
}

/**
* if java bin dir contains any special char, surround it by double quotes
* @param javaBin
Expand Down Expand Up @@ -634,12 +607,7 @@ private List<String> getProcessCommandsWithoutArgs() throws IOException, URISynt
String javaHome = System.getProperty("java.home");
String javaBin = javaHome + File.separator + "bin" + File.separator + "java";

String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
String imglib2Path = getPathFromClass(NativeType.class);
if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
&& !modelrunnerPath.contains(File.pathSeparator)))
modelrunnerPath = System.getProperty("java.class.path");
String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
String classpath = getCurrentClasspath();
ProtectionDomain protectionDomain = Tensorflow2Interface.class.getProtectionDomain();
String codeSource = protectionDomain.getCodeSource().getLocation().getPath();
String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString());
Expand All @@ -649,7 +617,7 @@ private List<String> getProcessCommandsWithoutArgs() throws IOException, URISynt
continue;
classpath += ff.getAbsolutePath() + File.pathSeparator;
}
String className = Tensorflow2Interface.class.getName();
String className = JavaWorker.class.getName();
List<String> command = new LinkedList<String>();
command.add(padSpecialJavaBin(javaBin));
command.add("-cp");
Expand All @@ -658,6 +626,25 @@ private List<String> getProcessCommandsWithoutArgs() throws IOException, URISynt
return command;
}

private static String getCurrentClasspath() throws UnsupportedEncodingException {

String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
String imglib2Path = getPathFromClass(NativeType.class);
String gsonPath = getPathFromClass(Gson.class);
String jnaPath = getPathFromClass(com.sun.jna.Library.class);
String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class);
if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
&& !modelrunnerPath.contains(File.pathSeparator)))
modelrunnerPath = System.getProperty("java.class.path");
modelrunnerPath = System.getProperty("java.class.path");
String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
classpath = classpath + gsonPath + File.pathSeparator;
classpath = classpath + jnaPath + File.pathSeparator;
classpath = classpath + jnaPlatformPath + File.pathSeparator;

return classpath;
}

/**
* Method that gets the path to the JAR from where a specific class is being loaded
* @param clazz
Expand Down Expand Up @@ -719,4 +706,51 @@ private static String getEnginesDir() {
}
return new File(dir).getParent();
}


/**
* Methods to run interprocessing and bypass the errors that occur in MacOS intel
* with the compatibility between TF2 and TF1/Pytorch
* This method checks that the arguments are correct, retrieves the input and output
* tensors, loads the model, makes inference with it and finally sends the tensors
* to the original process
*
* @param args
* arguments of the program:
* - Path to the model folder
* - Path to a temporary dir
* - Name of the input 0
* - Name of the input 1
* - ...
* - Name of the output n
* - Name of the output 0
* - Name of the output 1
* - ...
* - Name of the output n
* @throws LoadModelException if there is any error loading the model
* @throws IOException if there is any error reading or writing any file or with the paths
* @throws RunModelException if there is any error running the model
* @throws URISyntaxException
*/
public static void main(String[] args) throws LoadModelException, IOException, RunModelException, URISyntaxException {

String modelFolder = "/home/carlos/Desktop/Fiji.app/models/model_03bioimageio";
String modelSourc = modelFolder + "/weights-torchscript.pt";
Tensorflow2Interface pi = new Tensorflow2Interface();
try {
pi.loadModel(modelFolder, modelSourc);
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(new long[] {1, 512, 512, 1});
Tensor<?> inp = Tensor.build("aa", "byxc", rai);
//Tensor<?> out = Tensor.buildBlankTensor("oo", "bcyx", new long[] {1, 2, 512, 512}, new FloatType());
Tensor<?> out = Tensor.buildEmptyTensor("oo", "byxc");
List<Tensor<?>> ins = new ArrayList<Tensor<?>>();
List<Tensor<?>> ous = new ArrayList<Tensor<?>>();
ins.add(inp);
ous.add(out);
pi.run(ins, ous);
System.out.println(false);
} catch (Exception ex) {
pi.closeModel();
}
}
}

0 comments on commit 75fc740

Please sign in to comment.