1+ package  org .pytorch .executorch ;
2+ 
3+ import  java .nio .ByteBuffer ;
4+ import  java .nio .ByteOrder ;
5+ import  java .nio .FloatBuffer ;
6+ import  java .nio .IntBuffer ;
7+ import  java .nio .LongBuffer ;
8+ 
9+ /** 
10+  * Represents a multi-dimensional array (tensor) used for numerical computation in Executorch. 
11+  * This class wraps a native Executorch tensor and provides methods for its creation, 
12+  * access to its properties (shape, data type), and its underlying data buffer. 
13+  * <p> 
14+  * Tensor instances are {@link AutoCloseable} and must be closed to release native resources 
15+  * when they are no longer needed. Failure to do so can lead to memory leaks in native memory. 
16+  */ 
17+ public  class  Tensor  implements  AutoCloseable  {
18+     private  long  mNativePtr ;
19+ 
20+     /** 
21+      * Loads the native JNI library for Executorch operations. 
22+      * This static block ensures the library is loaded once when the class is first accessed. 
23+      */ 
24+     static  {
25+         System .loadLibrary ("executorch_jni" );
26+     }
27+ 
28+     // Native methods for existing Tensor operations. 
29+     private  native  long  nativeFromBlob (ByteBuffer  buffer , long [] sizes , int  dtypeValue );
30+     private  native  long [] nativeGetSizes (long  nativePtr );
31+     private  native  int  nativeGetDType (long  nativePtr );
32+     private  native  ByteBuffer  nativeGetDataBuffer (long  nativePtr );
33+     private  native  void  nativeClose (long  nativePtr );
34+ 
35+     // New native methods for creating specialized tensors. 
36+     private  static  native  long  nativeCreateOnesTensor (long [] shape , int  dtypeValue );
37+     private  static  native  long  nativeCreateZerosTensor (long [] shape , int  dtypeValue );
38+ 
39+     /** 
40+      * Private constructor to encapsulate native pointer management. 
41+      * Instances of {@code Tensor} should be created using static factory methods. 
42+      * 
43+      * @param nativePtr The pointer to the native {@code executorch::Tensor} object. 
44+      * @throws IllegalArgumentException If {@code nativePtr} is 0, indicating an invalid native object. 
45+      */ 
46+     private  Tensor (long  nativePtr ) {
47+         if  (nativePtr  == 0 ) {
48+             throw  new  IllegalArgumentException ("Native tensor pointer cannot be 0." );
49+         }
50+         this .mNativePtr  = nativePtr ;
51+     }
52+ 
53+     /** 
54+      * Creates a new {@code Tensor} by copying data from a direct {@link ByteBuffer} and 
55+      * specifies its shape and data type. The {@code ByteBuffer} must contain data 
56+      * in the specified {@code DType} and its capacity must match the total number of 
57+      * elements implied by the shape and dtype. 
58+      * 
59+      * @param buffer The direct {@link ByteBuffer} containing the tensor data. 
60+      *               The buffer's position should be at the start of the data and its limit 
61+      *               should define the end of the data. 
62+      * @param shape  An array of long integers representing the dimensions of the tensor. 
63+      *               For example, `{2, 3}` for a 2x3 matrix. 
64+      * @param dtype  The data type for the tensor's elements. 
65+      * @return A new {@code Tensor} instance initialized with the provided data. 
66+      * @throws IllegalArgumentException If {@code buffer}, {@code shape}, or {@code dtype} is null, 
67+      *                                  or if the {@code shape} array is empty. 
68+      */ 
69+     public  static  Tensor  fromBlob (ByteBuffer  buffer , long [] shape , DType  dtype ) {
70+         if  (buffer  == null ) {
71+             throw  new  IllegalArgumentException ("Input buffer cannot be null." );
72+         }
73+         if  (shape  == null  || shape .length  == 0 ) {
74+             throw  new  IllegalArgumentException ("Shape array cannot be null or empty." );
75+         }
76+         if  (dtype  == null ) {
77+             throw  new  IllegalArgumentException ("DType cannot be null." );
78+         }
79+ 
80+         long  nativePtr  = nativeFromBlob (buffer , shape , dtype .getValue ());
81+         return  new  Tensor (nativePtr );
82+     }
83+ 
84+     /** 
85+      * Creates a new tensor filled with ones, using the specified shape and a default 
86+      * data type of {@code DType.FLOAT32}. 
87+      * 
88+      * @param shape An array of long integers representing the dimensions of the tensor. 
89+      * @return A new {@code Tensor} instance initialized with ones. 
90+      * @throws IllegalArgumentException If the {@code shape} array is null or empty. 
91+      */ 
92+     public  static  Tensor  ones (long [] shape ) {
93+         return  ones (shape , DType .FLOAT32 ); // Default to FLOAT32, a common floating-point type. 
94+     }
95+ 
96+     /** 
97+      * Creates a new tensor filled with ones, using the specified shape and data type. 
98+      * The elements of the tensor will be initialized to the numerical value of '1' 
99+      * for the given data type (e.g., 1.0f for FLOAT32, 1L for INT64). 
100+      * 
101+      * @param shape An array of long integers representing the dimensions of the tensor. 
102+      * @param dtype The data type for the tensor's elements. 
103+      * @return A new {@code Tensor} instance initialized with ones. 
104+      * @throws IllegalArgumentException If the {@code shape} array or {@code dtype} is null, 
105+      *                                  or if the {@code shape} array is empty. 
106+      */ 
107+     public  static  Tensor  ones (long [] shape , DType  dtype ) {
108+         if  (shape  == null  || shape .length  == 0 ) {
109+             throw  new  IllegalArgumentException ("Shape array cannot be null or empty." );
110+         }
111+         if  (dtype  == null ) {
112+             throw  new  IllegalArgumentException ("DType cannot be null." );
113+         }
114+         long  nativePtr  = nativeCreateOnesTensor (shape , dtype .getValue ());
115+         return  new  Tensor (nativePtr );
116+     }
117+ 
118+     /** 
119+      * Creates a new tensor filled with zeros, using the specified shape and a default 
120+      * data type of {@code DType.FLOAT32}. 
121+      * 
122+      * @param shape An array of long integers representing the dimensions of the tensor. 
123+      * @return A new {@code Tensor} instance initialized with zeros. 
124+      * @throws IllegalArgumentException If the {@code shape} array is null or empty. 
125+      */ 
126+     public  static  Tensor  zeros (long [] shape ) {
127+         return  zeros (shape , DType .FLOAT32 ); // Default to FLOAT32, a common floating-point type. 
128+     }
129+ 
130+     /** 
131+      * Creates a new tensor filled with zeros, using the specified shape and data type. 
132+      * The elements of the tensor will be initialized to the numerical value of '0' 
133+      * for the given data type (e.g., 0.0f for FLOAT32, 0L for INT64). 
134+      * 
135+      * @param shape An array of long integers representing the dimensions of the tensor. 
136+      * @param dtype The data type for the tensor's elements. 
137+      * @return A new {@code Tensor} instance initialized with zeros. 
138+      * @throws IllegalArgumentException If the {@code shape} array or {@code dtype} is null, 
139+      *                                  or if the {@code shape} array is empty. 
140+      */ 
141+     public  static  Tensor  zeros (long [] shape , DType  dtype ) {
142+         if  (shape  == null  || shape .length  == 0 ) {
143+             throw  new  IllegalArgumentException ("Shape array cannot be null or empty." );
144+         }
145+         if  (dtype  == null ) {
146+             throw  new  IllegalArgumentException ("DType cannot be null." );
147+         }
148+         long  nativePtr  = nativeCreateZerosTensor (shape , dtype .getValue ());
149+         return  new  Tensor (nativePtr );
150+     }
151+ 
152+     /** 
153+      * Returns an array of long integers representing the dimensions (sizes) of this tensor. 
154+      * 
155+      * @return A new long array indicating the size of each dimension. 
156+      */ 
157+     public  long [] getSizes () {
158+         return  nativeGetSizes (mNativePtr );
159+     }
160+ 
161+     /** 
162+      * Returns the data type of the elements stored in this tensor. 
163+      * 
164+      * @return The {@code DType} enum value representing the tensor's data type. 
165+      */ 
166+     public  DType  getDType () {
167+         return  DType .fromValue (nativeGetDType (mNativePtr ));
168+     }
169+ 
170+     /** 
171+      * Returns a direct {@link ByteBuffer} that provides access to the underlying raw data of the tensor. 
172+      * The returned buffer is direct and its byte order is set to {@link ByteOrder#nativeOrder()}. 
173+      * The buffer's position and limit are set to encompass the entire tensor data. 
174+      * Modifying this buffer will modify the tensor's underlying data. 
175+      * 
176+      * @return A {@link ByteBuffer} providing direct access to the tensor's data. 
177+      */ 
178+     public  ByteBuffer  getByteBuffer () {
179+         return  nativeGetDataBuffer (mNativePtr );
180+     }
181+ 
182+     /** 
183+      * Returns a {@link FloatBuffer} that provides access to the underlying data of the tensor. 
184+      * This method assumes the tensor's data type is {@code DType.FLOAT32}. 
185+      * The buffer's byte order is set to {@link ByteOrder#nativeOrder()}. 
186+      * Modifying this buffer will modify the tensor's underlying data. 
187+      * 
188+      * @return A {@link FloatBuffer} providing access to the tensor's data. 
189+      * @throws IllegalStateException If the tensor's data type is not {@code DType.FLOAT32}. 
190+      */ 
191+     public  FloatBuffer  getFloatBuffer () {
192+         if  (getDType () != DType .FLOAT32 ) {
193+             throw  new  IllegalStateException ("Tensor is not of FLOAT32 type. Actual type: "  + getDType ());
194+         }
195+         return  getByteBuffer ().order (ByteOrder .nativeOrder ()).asFloatBuffer ();
196+     }
197+ 
198+     /** 
199+      * Returns an {@link IntBuffer} that provides access to the underlying data of the tensor. 
200+      * This method assumes the tensor's data type is {@code DType.INT32}. 
201+      * The buffer's byte order is set to {@link ByteOrder#nativeOrder()}. 
202+      * Modifying this buffer will modify the tensor's underlying data. 
203+      * 
204+      * @return An {@link IntBuffer} providing access to the tensor's data. 
205+      * @throws IllegalStateException If the tensor's data type is not {@code DType.INT32}. 
206+      */ 
207+     public  IntBuffer  getIntBuffer () {
208+         if  (getDType () != DType .INT32 ) {
209+             throw  new  IllegalStateException ("Tensor is not of INT32 type. Actual type: "  + getDType ());
210+         }
211+         return  getByteBuffer ().order (ByteOrder .nativeOrder ()).asIntBuffer ();
212+     }
213+ 
214+     /** 
215+      * Returns a {@link LongBuffer} that provides access to the underlying data of the tensor. 
216+      * This method assumes the tensor's data type is {@code DType.INT64}. 
217+      * The buffer's byte order is set to {@link ByteOrder#nativeOrder()}. 
218+      * Modifying this buffer will modify the tensor's underlying data. 
219+      * 
220+      * @return A {@link LongBuffer} providing access to the tensor's data. 
221+      * @throws IllegalStateException If the tensor's data type is not {@code DType.INT64}. 
222+      */ 
223+     public  LongBuffer  getLongBuffer () {
224+         if  (getDType () != DType .INT64 ) {
225+             throw  new  IllegalStateException ("Tensor is not of INT64 type. Actual type: "  + getDType ());
226+         }
227+         return  getByteBuffer ().order (ByteOrder .nativeOrder ()).asLongBuffer ();
228+     }
229+ 
230+     /** 
231+      * Releases the native resources associated with this tensor. 
232+      * After calling this method, the tensor object becomes invalid and its native pointer 
233+      * {@code mNativePtr} is set to 0. Any subsequent calls to methods interacting with 
234+      * native resources on this object will likely fail or lead to undefined behavior. 
235+      * This method can be called multiple times; subsequent calls on an already closed 
236+      * tensor will have no effect. 
237+      */ 
238+     @ Override 
239+     public  void  close () {
240+         if  (mNativePtr  != 0 ) {
241+             nativeClose (mNativePtr );
242+             mNativePtr  = 0 ; // Mark as released 
243+         }
244+     }
245+ 
246+     /** 
247+      * Called by the garbage collector on an object when garbage collection determines that 
248+      * there are no more references to the object. 
249+      * It attempts to release native resources by calling {@link #close()} if it has not 
250+      * been explicitly called by the user. 
251+      * <p> 
252+      * Note: Finalization is not guaranteed to run, and its timing is unpredictable. 
253+      * It is highly recommended to explicitly call {@link #close()} to manage native resources 
254+      * and avoid potential memory leaks or resource exhaustion. This {@code finalize} method 
255+      * serves as a last-resort cleanup mechanism. 
256+      * 
257+      * @throws Throwable if an error occurs during finalization. 
258+      */ 
259+     @ SuppressWarnings ("FinalizeDoesntCallSuperFinalize" ) // Super.finalize is called in the finally block 
260+     @ Override 
261+     protected  void  finalize () throws  Throwable  {
262+         try  {
263+             close ();
264+         } finally  {
265+             super .finalize (); // Ensure superclass finalization logic is also executed 
266+         }
267+     }
268+ }
0 commit comments