21
21
package io .bioimage .modelrunner .pytorch .javacpp .tensor ;
22
22
23
23
24
+ import java .util .Arrays ;
25
+
24
26
import io .bioimage .modelrunner .tensor .Utils ;
27
+ import io .bioimage .modelrunner .utils .CommonUtils ;
25
28
import io .bioimage .modelrunner .utils .IndexingUtils ;
26
29
import net .imglib2 .Cursor ;
27
30
import net .imglib2 .RandomAccessibleInterval ;
@@ -84,6 +87,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(org.bytedeco
84
87
*/
85
88
public static RandomAccessibleInterval <UnsignedByteType > buildFromTensorUByte (org .bytedeco .pytorch .Tensor tensor ) {
86
89
long [] arrayShape = tensor .shape ();
90
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
91
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
92
+ + " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
87
93
long [] tensorShape = new long [arrayShape .length ];
88
94
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
89
95
long flatSize = 1 ;
@@ -105,6 +111,9 @@ public static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(or
105
111
public static RandomAccessibleInterval <ByteType > buildFromTensorByte (org .bytedeco .pytorch .Tensor tensor )
106
112
{
107
113
long [] arrayShape = tensor .shape ();
114
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
115
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
116
+ + " is too big. Max number of elements per byte output tensor supported: " + Integer .MAX_VALUE / 1 );
108
117
long [] tensorShape = new long [arrayShape .length ];
109
118
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
110
119
long flatSize = 1 ;
@@ -125,6 +134,9 @@ public static RandomAccessibleInterval<ByteType> buildFromTensorByte(org.bytedec
125
134
public static RandomAccessibleInterval <IntType > buildFromTensorInt (org .bytedeco .pytorch .Tensor tensor )
126
135
{
127
136
long [] arrayShape = tensor .shape ();
137
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
138
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
139
+ + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
128
140
long [] tensorShape = new long [arrayShape .length ];
129
141
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
130
142
long flatSize = 1 ;
@@ -145,6 +157,9 @@ public static RandomAccessibleInterval<IntType> buildFromTensorInt(org.bytedeco.
145
157
public static RandomAccessibleInterval <FloatType > buildFromTensorFloat (org .bytedeco .pytorch .Tensor tensor )
146
158
{
147
159
long [] arrayShape = tensor .shape ();
160
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
161
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
162
+ + " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
148
163
long [] tensorShape = new long [arrayShape .length ];
149
164
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
150
165
long flatSize = 1 ;
@@ -165,6 +180,9 @@ public static RandomAccessibleInterval<FloatType> buildFromTensorFloat(org.byted
165
180
public static RandomAccessibleInterval <DoubleType > buildFromTensorDouble (org .bytedeco .pytorch .Tensor tensor )
166
181
{
167
182
long [] arrayShape = tensor .shape ();
183
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
184
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
185
+ + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
168
186
long [] tensorShape = new long [arrayShape .length ];
169
187
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
170
188
long flatSize = 1 ;
@@ -185,6 +203,9 @@ public static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(org.byt
185
203
public static RandomAccessibleInterval <LongType > buildFromTensorLong (org .bytedeco .pytorch .Tensor tensor )
186
204
{
187
205
long [] arrayShape = tensor .shape ();
206
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
207
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
208
+ + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
188
209
long [] tensorShape = new long [arrayShape .length ];
189
210
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
190
211
long flatSize = 1 ;
0 commit comments