From efe7c4f460c2a2c47ce3be32d525e839571ccb3c Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Sat, 23 Nov 2024 14:07:44 +0100 Subject: [PATCH] try avoiding copies --- .../modelrunner/pytorch/javacpp/shm/ShmBuilder.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java index 6d29a1f..ea45c1a 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -20,18 +20,21 @@ */ package io.bioimage.modelrunner.pytorch.javacpp.shm; +import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.utils.CommonUtils; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.FloatBuffer; import java.util.Arrays; import org.bytedeco.pytorch.Tensor; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; @@ -96,7 +99,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); + RandomAccessibleInterval rai = shma.getSharedRAI(); + rai = ImgLib2Builder.build(tensor); if (PlatformDetection.isWindows()) shma.close(); }