Skip to content

Commit 0400eeb

Browse files
committed
improve robustness
1 parent e8d0cb6 commit 0400eeb

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/ImgLib2Builder.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
package io.bioimage.modelrunner.pytorch.javacpp.tensor;
2222

2323

24+
import java.util.Arrays;
25+
2426
import io.bioimage.modelrunner.tensor.Utils;
27+
import io.bioimage.modelrunner.utils.CommonUtils;
2528
import io.bioimage.modelrunner.utils.IndexingUtils;
2629
import net.imglib2.Cursor;
2730
import net.imglib2.RandomAccessibleInterval;
@@ -84,6 +87,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(org.bytedeco
8487
*/
8588
public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(org.bytedeco.pytorch.Tensor tensor) {
8689
long[] arrayShape = tensor.shape();
90+
if (CommonUtils.int32Overflows(arrayShape, 1))
91+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
92+
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
8793
long[] tensorShape = new long[arrayShape.length];
8894
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
8995
long flatSize = 1;
@@ -105,6 +111,9 @@ public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(or
105111
public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedeco.pytorch.Tensor tensor)
106112
{
107113
long[] arrayShape = tensor.shape();
114+
if (CommonUtils.int32Overflows(arrayShape, 1))
115+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
116+
+ " is too big. Max number of elements per byte output tensor supported: " + Integer.MAX_VALUE / 1);
108117
long[] tensorShape = new long[arrayShape.length];
109118
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
110119
long flatSize = 1;
@@ -125,6 +134,9 @@ public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedec
125134
public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.pytorch.Tensor tensor)
126135
{
127136
long[] arrayShape = tensor.shape();
137+
if (CommonUtils.int32Overflows(arrayShape, 4))
138+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
139+
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
128140
long[] tensorShape = new long[arrayShape.length];
129141
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
130142
long flatSize = 1;
@@ -145,6 +157,9 @@ public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.
145157
public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.bytedeco.pytorch.Tensor tensor)
146158
{
147159
long[] arrayShape = tensor.shape();
160+
if (CommonUtils.int32Overflows(arrayShape, 4))
161+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
162+
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
148163
long[] tensorShape = new long[arrayShape.length];
149164
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
150165
long flatSize = 1;
@@ -165,6 +180,9 @@ public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.byted
165180
public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.bytedeco.pytorch.Tensor tensor)
166181
{
167182
long[] arrayShape = tensor.shape();
183+
if (CommonUtils.int32Overflows(arrayShape, 8))
184+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
185+
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
168186
long[] tensorShape = new long[arrayShape.length];
169187
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
170188
long flatSize = 1;
@@ -185,6 +203,9 @@ public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.byt
185203
public static RandomAccessibleInterval<LongType> buildFromTensorLong(org.bytedeco.pytorch.Tensor tensor)
186204
{
187205
long[] arrayShape = tensor.shape();
206+
if (CommonUtils.int32Overflows(arrayShape, 8))
207+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
208+
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
188209
long[] tensorShape = new long[arrayShape.length];
189210
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
190211
long flatSize = 1;

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch
9999
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
100100
{
101101
long[] ogShape = tensor.dimensionsAsLongArray();
102-
if (CommonUtils.int32Overflows(ogShape))
102+
if (CommonUtils.int32Overflows(ogShape, 1))
103103
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
104-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
104+
+ " is too big. Max number of elements per byte tensor supported: " + Integer.MAX_VALUE);
105105
tensor = Utils.transpose(tensor);
106106
long[] tensorShape = tensor.dimensionsAsLongArray();
107107
int size = 1;
@@ -132,9 +132,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
132132
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
133133
{
134134
long[] ogShape = tensor.dimensionsAsLongArray();
135-
if (CommonUtils.int32Overflows(ogShape))
135+
if (CommonUtils.int32Overflows(ogShape, 4))
136136
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
137-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
137+
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
138138
tensor = Utils.transpose(tensor);
139139
long[] tensorShape = tensor.dimensionsAsLongArray();
140140
int size = 1;
@@ -165,9 +165,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
165165
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
166166
{
167167
long[] ogShape = tensor.dimensionsAsLongArray();
168-
if (CommonUtils.int32Overflows(ogShape))
168+
if (CommonUtils.int32Overflows(ogShape, 4))
169169
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
170-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
170+
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
171171
tensor = Utils.transpose(tensor);
172172
long[] tensorShape = tensor.dimensionsAsLongArray();
173173
int size = 1;
@@ -198,9 +198,9 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
198198
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
199199
{
200200
long[] ogShape = tensor.dimensionsAsLongArray();
201-
if (CommonUtils.int32Overflows(ogShape))
201+
if (CommonUtils.int32Overflows(ogShape, 8))
202202
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
203-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
203+
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
204204
tensor = Utils.transpose(tensor);
205205
long[] tensorShape = tensor.dimensionsAsLongArray();
206206
int size = 1;

0 commit comments

Comments
 (0)