1+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2+ // All rights reserved.
3+ //
4+ // This source code is licensed under the BSD-style license found in the
5+ // LICENSE file in the root directory of this source tree.
6+
7+ package org .pytorch .executorch ;
8+
9+ import androidx .annotation .NonNull ;
10+
11+ /**
12+ * Represents a tensor in ExecuTorch, managing its native memory and providing operations
13+ * to interact with it.
14+ *
15+ * This class is designed to be used by Android applications to create, manipulate,
16+ * and pass tensors to ExecuTorch models. It implements {@link AutoCloseable} to
17+ * ensure proper release of native resources.
18+ */
19+ public class Tensor implements AutoCloseable {
20+
21+ // The pointer to the native (C++) Tensor object.
22+ // A value of 0 indicates that the native object has been released.
23+ private long nativeHandle ;
24+
25+ // --- Existing Native Method Declarations ---
26+ private static native long nativeNew (Object data , long [] shape , int dtype );
27+ private static native void nativeRelease (long nativeHandle );
28+ private static native long [] nativeGetShape (long nativeHandle );
29+ private static native int nativeGetDtype (long nativeHandle );
30+
31+ // --- NEW NATIVE DECLARATIONS ---
32+ private static native long nativeOnes (long [] shape , int dtype );
33+ private static native long nativeZeros (long [] shape , int dtype );
34+
35+ /**
36+ * Constructs a Tensor object from a native handle.
37+ * This constructor is primarily used internally after JNI calls create a native tensor.
38+ *
39+ * @param nativeHandle The native pointer to the underlying C++ Tensor object. Must not be 0.
40+ * @throws IllegalArgumentException if the nativeHandle is 0.
41+ */
42+ public Tensor (long nativeHandle ) {
43+ if (nativeHandle == 0 ) {
44+ throw new IllegalArgumentException ("Native handle cannot be 0." );
45+ }
46+ this .nativeHandle = nativeHandle ;
47+ }
48+
49+ /**
50+ * Creates a new tensor from a flat array of float data and a shape.
51+ * The data type of the tensor will be {@code ScalarType.FLOAT}.
52+ *
53+ * @param data The flat array containing the tensor's float data.
54+ * @param shape The desired shape of the tensor. An empty array {@code new long[0]}
55+ * represents a scalar (0-D) tensor.
56+ * @return A new Tensor.
57+ * @throws IllegalArgumentException if data is null, shape is null, or native allocation fails.
58+ */
59+ public static Tensor fromBlob (@ NonNull float [] data , @ NonNull long [] shape ) {
60+ if (data == null ) {
61+ throw new IllegalArgumentException ("Data cannot be null." );
62+ }
63+ if (shape == null ) {
64+ throw new IllegalArgumentException ("Shape cannot be null." );
65+ }
66+ // It's generally good practice for the Java side to validate data.length against product(shape)
67+ // for early error detection, but we rely on the native side for now.
68+ long nativeHandle = nativeNew (data , shape , ScalarType .FLOAT .getValue ());
69+ if (nativeHandle == 0 ) {
70+ throw new IllegalArgumentException ("Failed to create native Tensor from float blob." );
71+ }
72+ return new Tensor (nativeHandle );
73+ }
74+
75+ // TODO: Add other `fromBlob` overloads for different primitive types (e.g., int[], byte[]).
76+
77+ /**
78+ * Returns the shape of the tensor.
79+ *
80+ * @return An array of long representing the dimensions of the tensor. An empty array signifies
81+ * a scalar (0-D) tensor.
82+ * @throws IllegalStateException if the tensor has been released.
83+ */
84+ @ NonNull
85+ public long [] getShape () {
86+ if (nativeHandle == 0 ) {
87+ throw new IllegalStateException ("Tensor has been released." );
88+ }
89+ return nativeGetShape (nativeHandle );
90+ }
91+
92+ /**
93+ * Returns the data type of the tensor.
94+ *
95+ * @return The {@code ScalarType} enum value representing the tensor's data type.
96+ * @throws IllegalStateException if the tensor has been released.
97+ */
98+ @ NonNull
99+ public ScalarType getDtype () {
100+ if (nativeHandle == 0 ) {
101+ throw new IllegalStateException ("Tensor has been released." );
102+ }
103+ return ScalarType .fromValue (nativeGetDtype (nativeHandle ));
104+ }
105+
106+ // --- NEW PUBLIC STATIC CONVENIENCE METHODS ---
107+
108+ /**
109+ * Creates a new tensor with the specified shape and fills it with ones.
110+ * The data type of the tensor will be {@code ScalarType.FLOAT} by default.
111+ *
112+ * @param shape The desired shape of the tensor. An empty array {@code new long[0]}
113+ * represents a scalar (0-D) tensor.
114+ * @return A new Tensor filled with ones.
115+ * @throws IllegalArgumentException if the shape is null or native allocation fails.
116+ */
117+ public static Tensor ones (@ NonNull long [] shape ) {
118+ return ones (shape , ScalarType .FLOAT );
119+ }
120+
121+ /**
122+ * Creates a new tensor with the specified shape and fills it with ones.
123+ *
124+ * @param shape The desired shape of the tensor. An empty array {@code new long[0]}
125+ * represents a scalar (0-D) tensor.
126+ * @param dtype The desired data type of the tensor.
127+ * @return A new Tensor filled with ones.
128+ * @throws IllegalArgumentException if the shape is null, dtype is null, or native allocation fails.
129+ */
130+ public static Tensor ones (@ NonNull long [] shape , @ NonNull ScalarType dtype ) {
131+ if (shape == null ) {
132+ throw new IllegalArgumentException ("Shape cannot be null." );
133+ }
134+ if (dtype == null ) {
135+ throw new IllegalArgumentException ("Dtype cannot be null." );
136+ }
137+ long nativeHandle = nativeOnes (shape , dtype .getValue ());
138+ if (nativeHandle == 0 ) {
139+ throw new IllegalArgumentException ("Failed to create native Tensor with ones." );
140+ }
141+ return new Tensor (nativeHandle );
142+ }
143+
144+ /**
145+ * Creates a new tensor with the specified shape and fills it with zeros.
146+ * The data type of the tensor will be {@code ScalarType.FLOAT} by default.
147+ *
148+ * @param shape The desired shape of the tensor. An empty array {@code new long[0]}
149+ * represents a scalar (0-D) tensor.
150+ * @return A new Tensor filled with zeros.
151+ * @throws IllegalArgumentException if the shape is null or native allocation fails.
152+ */
153+ public static Tensor zeros (@ NonNull long [] shape ) {
154+ return zeros (shape , ScalarType .FLOAT );
155+ }
156+
157+ /**
158+ * Creates a new tensor with the specified shape and fills it with zeros.
159+ *
160+ * @param shape The desired shape of the tensor. An empty array {@code new long[0]}
161+ * represents a scalar (0-D) tensor.
162+ * @param dtype The desired data type of the tensor.
163+ * @return A new Tensor filled with zeros.
164+ * @throws IllegalArgumentException if the shape is null, dtype is null, or native allocation fails.
165+ */
166+ public static Tensor zeros (@ NonNull long [] shape , @ NonNull ScalarType dtype ) {
167+ if (shape == null ) {
168+ throw new IllegalArgumentException ("Shape cannot be null." );
169+ }
170+ if (dtype == null ) {
171+ throw new IllegalArgumentException ("Dtype cannot be null." );
172+ }
173+ long nativeHandle = nativeZeros (shape , dtype .getValue ());
174+ if (nativeHandle == 0 ) {
175+ throw new IllegalArgumentException ("Failed to create native Tensor with zeros." );
176+ }
177+ return new Tensor (nativeHandle );
178+ }
179+
180+ /**
181+ * Releases the native resources associated with this tensor.
182+ * After this method is called, the Tensor object becomes invalid.
183+ * This method is automatically called when the Tensor is used in a try-with-resources statement.
184+ */
185+ @ Override
186+ public void close () {
187+ if (nativeHandle != 0 ) {
188+ nativeRelease (nativeHandle );
189+ nativeHandle = 0 ;
190+ }
191+ }
192+
193+ // Static initializer to load the JNI library.
194+ // In a typical Android application, the JNI library might be loaded
195+ // once in a higher-level entry point (e.g., Application class or a Module class).
196+ // This block ensures the library is loaded if not already.
197+ static {
198+ try {
199+ System .loadLibrary ("executorch_android_jni" );
200+ } catch (UnsatisfiedLinkError e ) {
201+ System .err .println ("Failed to load native library 'executorch_android_jni': " + e .getMessage ());
202+ // For a core library class like Tensor, if native functionality is essential,
203+ // it's appropriate to rethrow the error to indicate a critical setup failure.
204+ throw e ;
205+ }
206+ }
207+ }
0 commit comments