Skip to content

Commit

Permalink
improve robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Mar 25, 2024
1 parent e8d0cb6 commit 0400eeb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
package io.bioimage.modelrunner.pytorch.javacpp.tensor;


import java.util.Arrays;

import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.IndexingUtils;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
Expand Down Expand Up @@ -84,6 +87,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(org.bytedeco
*/
public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) {
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand All @@ -105,6 +111,9 @@ public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(or
public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor)
{
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per byte output tensor supported: " + Integer.MAX_VALUE / 1);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand All @@ -125,6 +134,9 @@ public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedec
public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor)
{
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand All @@ -145,6 +157,9 @@ public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.
public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor)
{
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand All @@ -165,6 +180,9 @@ public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.byted
public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor)
{
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand All @@ -185,6 +203,9 @@ public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.byt
public static RandomAccessibleInterval<LongType> buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor)
{
long[] arrayShape = tensor.shape();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
long flatSize = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 1))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per byte tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -132,9 +132,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -165,9 +165,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -198,9 +198,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 8))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down

0 comments on commit 0400eeb

Please sign in to comment.