Skip to content

Commit

Permalink
try avoiding copies
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 23, 2024
1 parent a937f90 commit efe7c4f
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@
*/
package io.bioimage.modelrunner.pytorch.javacpp.shm;

import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Arrays;

import org.bytedeco.pytorch.Tensor;

import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
Expand Down Expand Up @@ -96,7 +99,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
RandomAccessibleInterval<?> rai = shma.getSharedRAI();
rai = ImgLib2Builder.build(tensor);
if (PlatformDetection.isWindows()) shma.close();
}

Expand Down

0 comments on commit efe7c4f

Please sign in to comment.