Skip to content

Commit b9c5f09

Browse files
committed
fix: Resolve issue #15125
1 parent 1a5eaec commit b9c5f09

File tree

1 file changed

+268
-0
lines changed
  • extension/android/src/main/java/org/pytorch/executorch

1 file changed

+268
-0
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
package org.pytorch.executorch;
2+
3+
import java.nio.ByteBuffer;
4+
import java.nio.ByteOrder;
5+
import java.nio.FloatBuffer;
6+
import java.nio.IntBuffer;
7+
import java.nio.LongBuffer;
8+
9+
/**
10+
* Represents a multi-dimensional array (tensor) used for numerical computation in Executorch.
11+
* This class wraps a native Executorch tensor and provides methods for its creation,
12+
* access to its properties (shape, data type), and its underlying data buffer.
13+
* <p>
14+
* Tensor instances are {@link AutoCloseable} and must be closed to release native resources
15+
* when they are no longer needed. Failure to do so can lead to memory leaks in native memory.
16+
*/
17+
public class Tensor implements AutoCloseable {
18+
private long mNativePtr;
19+
20+
/**
21+
* Loads the native JNI library for Executorch operations.
22+
* This static block ensures the library is loaded once when the class is first accessed.
23+
*/
24+
static {
25+
System.loadLibrary("executorch_jni");
26+
}
27+
28+
// Native methods for existing Tensor operations.
29+
private native long nativeFromBlob(ByteBuffer buffer, long[] sizes, int dtypeValue);
30+
private native long[] nativeGetSizes(long nativePtr);
31+
private native int nativeGetDType(long nativePtr);
32+
private native ByteBuffer nativeGetDataBuffer(long nativePtr);
33+
private native void nativeClose(long nativePtr);
34+
35+
// New native methods for creating specialized tensors.
36+
private static native long nativeCreateOnesTensor(long[] shape, int dtypeValue);
37+
private static native long nativeCreateZerosTensor(long[] shape, int dtypeValue);
38+
39+
/**
40+
* Private constructor to encapsulate native pointer management.
41+
* Instances of {@code Tensor} should be created using static factory methods.
42+
*
43+
* @param nativePtr The pointer to the native {@code executorch::Tensor} object.
44+
* @throws IllegalArgumentException If {@code nativePtr} is 0, indicating an invalid native object.
45+
*/
46+
private Tensor(long nativePtr) {
47+
if (nativePtr == 0) {
48+
throw new IllegalArgumentException("Native tensor pointer cannot be 0.");
49+
}
50+
this.mNativePtr = nativePtr;
51+
}
52+
53+
/**
54+
* Creates a new {@code Tensor} by copying data from a direct {@link ByteBuffer} and
55+
* specifies its shape and data type. The {@code ByteBuffer} must contain data
56+
* in the specified {@code DType} and its capacity must match the total number of
57+
* elements implied by the shape and dtype.
58+
*
59+
* @param buffer The direct {@link ByteBuffer} containing the tensor data.
60+
* The buffer's position should be at the start of the data and its limit
61+
* should define the end of the data.
62+
* @param shape An array of long integers representing the dimensions of the tensor.
63+
* For example, `{2, 3}` for a 2x3 matrix.
64+
* @param dtype The data type for the tensor's elements.
65+
* @return A new {@code Tensor} instance initialized with the provided data.
66+
* @throws IllegalArgumentException If {@code buffer}, {@code shape}, or {@code dtype} is null,
67+
* or if the {@code shape} array is empty.
68+
*/
69+
public static Tensor fromBlob(ByteBuffer buffer, long[] shape, DType dtype) {
70+
if (buffer == null) {
71+
throw new IllegalArgumentException("Input buffer cannot be null.");
72+
}
73+
if (shape == null || shape.length == 0) {
74+
throw new IllegalArgumentException("Shape array cannot be null or empty.");
75+
}
76+
if (dtype == null) {
77+
throw new IllegalArgumentException("DType cannot be null.");
78+
}
79+
80+
long nativePtr = nativeFromBlob(buffer, shape, dtype.getValue());
81+
return new Tensor(nativePtr);
82+
}
83+
84+
/**
85+
* Creates a new tensor filled with ones, using the specified shape and a default
86+
* data type of {@code DType.FLOAT32}.
87+
*
88+
* @param shape An array of long integers representing the dimensions of the tensor.
89+
* @return A new {@code Tensor} instance initialized with ones.
90+
* @throws IllegalArgumentException If the {@code shape} array is null or empty.
91+
*/
92+
public static Tensor ones(long[] shape) {
93+
return ones(shape, DType.FLOAT32); // Default to FLOAT32, a common floating-point type.
94+
}
95+
96+
/**
97+
* Creates a new tensor filled with ones, using the specified shape and data type.
98+
* The elements of the tensor will be initialized to the numerical value of '1'
99+
* for the given data type (e.g., 1.0f for FLOAT32, 1L for INT64).
100+
*
101+
* @param shape An array of long integers representing the dimensions of the tensor.
102+
* @param dtype The data type for the tensor's elements.
103+
* @return A new {@code Tensor} instance initialized with ones.
104+
* @throws IllegalArgumentException If the {@code shape} array or {@code dtype} is null,
105+
* or if the {@code shape} array is empty.
106+
*/
107+
public static Tensor ones(long[] shape, DType dtype) {
108+
if (shape == null || shape.length == 0) {
109+
throw new IllegalArgumentException("Shape array cannot be null or empty.");
110+
}
111+
if (dtype == null) {
112+
throw new IllegalArgumentException("DType cannot be null.");
113+
}
114+
long nativePtr = nativeCreateOnesTensor(shape, dtype.getValue());
115+
return new Tensor(nativePtr);
116+
}
117+
118+
/**
119+
* Creates a new tensor filled with zeros, using the specified shape and a default
120+
* data type of {@code DType.FLOAT32}.
121+
*
122+
* @param shape An array of long integers representing the dimensions of the tensor.
123+
* @return A new {@code Tensor} instance initialized with zeros.
124+
* @throws IllegalArgumentException If the {@code shape} array is null or empty.
125+
*/
126+
public static Tensor zeros(long[] shape) {
127+
return zeros(shape, DType.FLOAT32); // Default to FLOAT32, a common floating-point type.
128+
}
129+
130+
/**
131+
* Creates a new tensor filled with zeros, using the specified shape and data type.
132+
* The elements of the tensor will be initialized to the numerical value of '0'
133+
* for the given data type (e.g., 0.0f for FLOAT32, 0L for INT64).
134+
*
135+
* @param shape An array of long integers representing the dimensions of the tensor.
136+
* @param dtype The data type for the tensor's elements.
137+
* @return A new {@code Tensor} instance initialized with zeros.
138+
* @throws IllegalArgumentException If the {@code shape} array or {@code dtype} is null,
139+
* or if the {@code shape} array is empty.
140+
*/
141+
public static Tensor zeros(long[] shape, DType dtype) {
142+
if (shape == null || shape.length == 0) {
143+
throw new IllegalArgumentException("Shape array cannot be null or empty.");
144+
}
145+
if (dtype == null) {
146+
throw new IllegalArgumentException("DType cannot be null.");
147+
}
148+
long nativePtr = nativeCreateZerosTensor(shape, dtype.getValue());
149+
return new Tensor(nativePtr);
150+
}
151+
152+
/**
153+
* Returns an array of long integers representing the dimensions (sizes) of this tensor.
154+
*
155+
* @return A new long array indicating the size of each dimension.
156+
*/
157+
public long[] getSizes() {
158+
return nativeGetSizes(mNativePtr);
159+
}
160+
161+
/**
162+
* Returns the data type of the elements stored in this tensor.
163+
*
164+
* @return The {@code DType} enum value representing the tensor's data type.
165+
*/
166+
public DType getDType() {
167+
return DType.fromValue(nativeGetDType(mNativePtr));
168+
}
169+
170+
/**
171+
* Returns a direct {@link ByteBuffer} that provides access to the underlying raw data of the tensor.
172+
* The returned buffer is direct and its byte order is set to {@link ByteOrder#nativeOrder()}.
173+
* The buffer's position and limit are set to encompass the entire tensor data.
174+
* Modifying this buffer will modify the tensor's underlying data.
175+
*
176+
* @return A {@link ByteBuffer} providing direct access to the tensor's data.
177+
*/
178+
public ByteBuffer getByteBuffer() {
179+
return nativeGetDataBuffer(mNativePtr);
180+
}
181+
182+
/**
183+
* Returns a {@link FloatBuffer} that provides access to the underlying data of the tensor.
184+
* This method assumes the tensor's data type is {@code DType.FLOAT32}.
185+
* The buffer's byte order is set to {@link ByteOrder#nativeOrder()}.
186+
* Modifying this buffer will modify the tensor's underlying data.
187+
*
188+
* @return A {@link FloatBuffer} providing access to the tensor's data.
189+
* @throws IllegalStateException If the tensor's data type is not {@code DType.FLOAT32}.
190+
*/
191+
public FloatBuffer getFloatBuffer() {
192+
if (getDType() != DType.FLOAT32) {
193+
throw new IllegalStateException("Tensor is not of FLOAT32 type. Actual type: " + getDType());
194+
}
195+
return getByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer();
196+
}
197+
198+
/**
199+
* Returns an {@link IntBuffer} that provides access to the underlying data of the tensor.
200+
* This method assumes the tensor's data type is {@code DType.INT32}.
201+
* The buffer's byte order is set to {@link ByteOrder#nativeOrder()}.
202+
* Modifying this buffer will modify the tensor's underlying data.
203+
*
204+
* @return An {@link IntBuffer} providing access to the tensor's data.
205+
* @throws IllegalStateException If the tensor's data type is not {@code DType.INT32}.
206+
*/
207+
public IntBuffer getIntBuffer() {
208+
if (getDType() != DType.INT32) {
209+
throw new IllegalStateException("Tensor is not of INT32 type. Actual type: " + getDType());
210+
}
211+
return getByteBuffer().order(ByteOrder.nativeOrder()).asIntBuffer();
212+
}
213+
214+
/**
215+
* Returns a {@link LongBuffer} that provides access to the underlying data of the tensor.
216+
* This method assumes the tensor's data type is {@code DType.INT64}.
217+
* The buffer's byte order is set to {@link ByteOrder#nativeOrder()}.
218+
* Modifying this buffer will modify the tensor's underlying data.
219+
*
220+
* @return A {@link LongBuffer} providing access to the tensor's data.
221+
* @throws IllegalStateException If the tensor's data type is not {@code DType.INT64}.
222+
*/
223+
public LongBuffer getLongBuffer() {
224+
if (getDType() != DType.INT64) {
225+
throw new IllegalStateException("Tensor is not of INT64 type. Actual type: " + getDType());
226+
}
227+
return getByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer();
228+
}
229+
230+
/**
231+
* Releases the native resources associated with this tensor.
232+
* After calling this method, the tensor object becomes invalid and its native pointer
233+
* {@code mNativePtr} is set to 0. Any subsequent calls to methods interacting with
234+
* native resources on this object will likely fail or lead to undefined behavior.
235+
* This method can be called multiple times; subsequent calls on an already closed
236+
* tensor will have no effect.
237+
*/
238+
@Override
239+
public void close() {
240+
if (mNativePtr != 0) {
241+
nativeClose(mNativePtr);
242+
mNativePtr = 0; // Mark as released
243+
}
244+
}
245+
246+
/**
247+
* Called by the garbage collector on an object when garbage collection determines that
248+
* there are no more references to the object.
249+
* It attempts to release native resources by calling {@link #close()} if it has not
250+
* been explicitly called by the user.
251+
* <p>
252+
* Note: Finalization is not guaranteed to run, and its timing is unpredictable.
253+
* It is highly recommended to explicitly call {@link #close()} to manage native resources
254+
* and avoid potential memory leaks or resource exhaustion. This {@code finalize} method
255+
* serves as a last-resort cleanup mechanism.
256+
*
257+
* @throws Throwable if an error occurs during finalization.
258+
*/
259+
@SuppressWarnings("FinalizeDoesntCallSuperFinalize") // Super.finalize is called in the finally block
260+
@Override
261+
protected void finalize() throws Throwable {
262+
try {
263+
close();
264+
} finally {
265+
super.finalize(); // Ensure superclass finalization logic is also executed
266+
}
267+
}
268+
}

0 commit comments

Comments
 (0)