23
23
24
24
import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
25
25
import io .bioimage .modelrunner .utils .CommonUtils ;
26
- import net .imglib2 .RandomAccessibleInterval ;
27
- import net .imglib2 .img .Img ;
28
- import net .imglib2 .type .numeric .integer .IntType ;
29
- import net .imglib2 .type .numeric .integer .LongType ;
30
- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
31
- import net .imglib2 .type .numeric .real .DoubleType ;
32
- import net .imglib2 .type .numeric .real .FloatType ;
33
26
import net .imglib2 .util .Cast ;
34
27
35
28
import java .nio .ByteBuffer ;
55
48
import org .tensorflow .types .family .TType ;
56
49
57
50
/**
58
- * A TensorFlow 2 {@link Tensor} builder from {@link Img} and
59
- * {@link io.bioimage.modelrunner.tensor.Tensor} objects.
51
+ * Utility class to build Pytorch Bytedeco tensors from shm segments using {@link SharedMemoryArray}
60
52
*
61
- * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
53
+ * @author Carlos Garcia Lopez de Haro
62
54
*/
63
55
public final class TensorBuilder {
64
56
@@ -68,21 +60,17 @@ public final class TensorBuilder {
68
60
private TensorBuilder () {}
69
61
70
62
/**
71
- * Creates {@link TType} instance with the same size and information as the
72
- * given {@link RandomAccessibleInterval}.
63
+ * Creates {@link TType} instance from a {@link SharedMemoryArray}
73
64
*
74
- * @param <T>
75
- * the ImgLib2 data types the {@link RandomAccessibleInterval} can be
76
65
* @param array
77
- * the {@link RandomAccessibleInterval } that is going to be converted into
66
+ * the {@link SharedMemoryArray } that is going to be converted into
78
67
* a {@link TType} tensor
79
- * @return a {@link TType} tensor
80
- * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval }
68
+ * @return the Pytorch {@link TType} as the one stored in the shared memory segment
69
+ * @throws IllegalArgumentException if the type of the {@link SharedMemoryArray }
81
70
* is not supported
82
71
*/
83
72
public static TType build (SharedMemoryArray array ) throws IllegalArgumentException
84
73
{
85
- // Create an Icy sequence of the same type of the tensor
86
74
if (array .getOriginalDataType ().equals ("uint8" )) {
87
75
return buildUByte (Cast .unchecked (array ));
88
76
}
@@ -103,17 +91,7 @@ else if (array.getOriginalDataType().equals("int64")) {
103
91
}
104
92
}
105
93
106
- /**
107
- * Creates a {@link TType} tensor of type {@link TUint8} from an
108
- * {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
109
- *
110
- * @param tensor
111
- * The {@link RandomAccessibleInterval} to fill the tensor with.
112
- * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
113
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
114
- * not compatible
115
- */
116
- public static TUint8 buildUByte (SharedMemoryArray tensor )
94
+ private static TUint8 buildUByte (SharedMemoryArray tensor )
117
95
throws IllegalArgumentException
118
96
{
119
97
long [] ogShape = tensor .getOriginalShape ();
@@ -128,17 +106,7 @@ public static TUint8 buildUByte(SharedMemoryArray tensor)
128
106
return ndarray ;
129
107
}
130
108
131
- /**
132
- * Creates a {@link TInt32} tensor of type {@link TInt32} from an
133
- * {@link RandomAccessibleInterval} of type {@link IntType}
134
- *
135
- * @param tensor
136
- * The {@link RandomAccessibleInterval} to fill the tensor with.
137
- * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
138
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
139
- * not compatible
140
- */
141
- public static TInt32 buildInt (SharedMemoryArray tensor )
109
+ private static TInt32 buildInt (SharedMemoryArray tensor )
142
110
throws IllegalArgumentException
143
111
{
144
112
long [] ogShape = tensor .getOriginalShape ();
@@ -157,16 +125,6 @@ public static TInt32 buildInt(SharedMemoryArray tensor)
157
125
return ndarray ;
158
126
}
159
127
160
- /**
161
- * Creates a {@link TInt64} tensor of type {@link TInt64} from an
162
- * {@link RandomAccessibleInterval} of type {@link LongType}
163
- *
164
- * @param tensor
165
- * The {@link RandomAccessibleInterval} to fill the tensor with.
166
- * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
167
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
168
- * not compatible
169
- */
170
128
private static TInt64 buildLong (SharedMemoryArray tensor )
171
129
throws IllegalArgumentException
172
130
{
@@ -186,17 +144,7 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
186
144
return ndarray ;
187
145
}
188
146
189
- /**
190
- * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
191
- * {@link RandomAccessibleInterval} of type {@link FloatType}
192
- *
193
- * @param tensor
194
- * The {@link RandomAccessibleInterval} to fill the tensor with.
195
- * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
196
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
197
- * not compatible
198
- */
199
- public static TFloat32 buildFloat (SharedMemoryArray tensor )
147
+ private static TFloat32 buildFloat (SharedMemoryArray tensor )
200
148
throws IllegalArgumentException
201
149
{
202
150
long [] ogShape = tensor .getOriginalShape ();
@@ -214,16 +162,6 @@ public static TFloat32 buildFloat(SharedMemoryArray tensor)
214
162
return ndarray ;
215
163
}
216
164
217
- /**
218
- * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
219
- * {@link RandomAccessibleInterval} of type {@link DoubleType}
220
- *
221
- * @param tensor
222
- * The {@link RandomAccessibleInterval} to fill the tensor with.
223
- * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
224
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
225
- * not compatible
226
- */
227
165
private static TFloat64 buildDouble (SharedMemoryArray tensor )
228
166
throws IllegalArgumentException
229
167
{
0 commit comments