Skip to content

Commit 1e91562

Browse files
committed
fix: Resolve issue #15125
1 parent 1a5eaec commit 1e91562

File tree

1 file changed

+207
-0
lines changed
  • extension/android/java/org/pytorch/executorch

1 file changed

+207
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
package org.pytorch.executorch;
8+
9+
import androidx.annotation.NonNull;
10+
11+
/**
12+
* Represents a tensor in ExecuTorch, managing its native memory and providing operations
13+
* to interact with it.
14+
*
15+
* This class is designed to be used by Android applications to create, manipulate,
16+
* and pass tensors to ExecuTorch models. It implements {@link AutoCloseable} to
17+
* ensure proper release of native resources.
18+
*/
19+
public class Tensor implements AutoCloseable {
20+
21+
// The pointer to the native (C++) Tensor object.
22+
// A value of 0 indicates that the native object has been released.
23+
private long nativeHandle;
24+
25+
// --- Existing Native Method Declarations ---
26+
private static native long nativeNew(Object data, long[] shape, int dtype);
27+
private static native void nativeRelease(long nativeHandle);
28+
private static native long[] nativeGetShape(long nativeHandle);
29+
private static native int nativeGetDtype(long nativeHandle);
30+
31+
// --- NEW NATIVE DECLARATIONS ---
32+
private static native long nativeOnes(long[] shape, int dtype);
33+
private static native long nativeZeros(long[] shape, int dtype);
34+
35+
/**
36+
* Constructs a Tensor object from a native handle.
37+
* This constructor is primarily used internally after JNI calls create a native tensor.
38+
*
39+
* @param nativeHandle The native pointer to the underlying C++ Tensor object. Must not be 0.
40+
* @throws IllegalArgumentException if the nativeHandle is 0.
41+
*/
42+
public Tensor(long nativeHandle) {
43+
if (nativeHandle == 0) {
44+
throw new IllegalArgumentException("Native handle cannot be 0.");
45+
}
46+
this.nativeHandle = nativeHandle;
47+
}
48+
49+
/**
50+
* Creates a new tensor from a flat array of float data and a shape.
51+
* The data type of the tensor will be {@code ScalarType.FLOAT}.
52+
*
53+
* @param data The flat array containing the tensor's float data.
54+
* @param shape The desired shape of the tensor. An empty array {@code new long[0]}
55+
* represents a scalar (0-D) tensor.
56+
* @return A new Tensor.
57+
* @throws IllegalArgumentException if data is null, shape is null, or native allocation fails.
58+
*/
59+
public static Tensor fromBlob(@NonNull float[] data, @NonNull long[] shape) {
60+
if (data == null) {
61+
throw new IllegalArgumentException("Data cannot be null.");
62+
}
63+
if (shape == null) {
64+
throw new IllegalArgumentException("Shape cannot be null.");
65+
}
66+
// It's generally good practice for the Java side to validate data.length against product(shape)
67+
// for early error detection, but we rely on the native side for now.
68+
long nativeHandle = nativeNew(data, shape, ScalarType.FLOAT.getValue());
69+
if (nativeHandle == 0) {
70+
throw new IllegalArgumentException("Failed to create native Tensor from float blob.");
71+
}
72+
return new Tensor(nativeHandle);
73+
}
74+
75+
// TODO: Add other `fromBlob` overloads for different primitive types (e.g., int[], byte[]).
76+
77+
/**
78+
* Returns the shape of the tensor.
79+
*
80+
* @return An array of long representing the dimensions of the tensor. An empty array signifies
81+
* a scalar (0-D) tensor.
82+
* @throws IllegalStateException if the tensor has been released.
83+
*/
84+
@NonNull
85+
public long[] getShape() {
86+
if (nativeHandle == 0) {
87+
throw new IllegalStateException("Tensor has been released.");
88+
}
89+
return nativeGetShape(nativeHandle);
90+
}
91+
92+
/**
93+
* Returns the data type of the tensor.
94+
*
95+
* @return The {@code ScalarType} enum value representing the tensor's data type.
96+
* @throws IllegalStateException if the tensor has been released.
97+
*/
98+
@NonNull
99+
public ScalarType getDtype() {
100+
if (nativeHandle == 0) {
101+
throw new IllegalStateException("Tensor has been released.");
102+
}
103+
return ScalarType.fromValue(nativeGetDtype(nativeHandle));
104+
}
105+
106+
// --- NEW PUBLIC STATIC CONVENIENCE METHODS ---
107+
108+
/**
109+
* Creates a new tensor with the specified shape and fills it with ones.
110+
* The data type of the tensor will be {@code ScalarType.FLOAT} by default.
111+
*
112+
* @param shape The desired shape of the tensor. An empty array {@code new long[0]}
113+
* represents a scalar (0-D) tensor.
114+
* @return A new Tensor filled with ones.
115+
* @throws IllegalArgumentException if the shape is null or native allocation fails.
116+
*/
117+
public static Tensor ones(@NonNull long[] shape) {
118+
return ones(shape, ScalarType.FLOAT);
119+
}
120+
121+
/**
122+
* Creates a new tensor with the specified shape and fills it with ones.
123+
*
124+
* @param shape The desired shape of the tensor. An empty array {@code new long[0]}
125+
* represents a scalar (0-D) tensor.
126+
* @param dtype The desired data type of the tensor.
127+
* @return A new Tensor filled with ones.
128+
* @throws IllegalArgumentException if the shape is null, dtype is null, or native allocation fails.
129+
*/
130+
public static Tensor ones(@NonNull long[] shape, @NonNull ScalarType dtype) {
131+
if (shape == null) {
132+
throw new IllegalArgumentException("Shape cannot be null.");
133+
}
134+
if (dtype == null) {
135+
throw new IllegalArgumentException("Dtype cannot be null.");
136+
}
137+
long nativeHandle = nativeOnes(shape, dtype.getValue());
138+
if (nativeHandle == 0) {
139+
throw new IllegalArgumentException("Failed to create native Tensor with ones.");
140+
}
141+
return new Tensor(nativeHandle);
142+
}
143+
144+
/**
145+
* Creates a new tensor with the specified shape and fills it with zeros.
146+
* The data type of the tensor will be {@code ScalarType.FLOAT} by default.
147+
*
148+
* @param shape The desired shape of the tensor. An empty array {@code new long[0]}
149+
* represents a scalar (0-D) tensor.
150+
* @return A new Tensor filled with zeros.
151+
* @throws IllegalArgumentException if the shape is null or native allocation fails.
152+
*/
153+
public static Tensor zeros(@NonNull long[] shape) {
154+
return zeros(shape, ScalarType.FLOAT);
155+
}
156+
157+
/**
158+
* Creates a new tensor with the specified shape and fills it with zeros.
159+
*
160+
* @param shape The desired shape of the tensor. An empty array {@code new long[0]}
161+
* represents a scalar (0-D) tensor.
162+
* @param dtype The desired data type of the tensor.
163+
* @return A new Tensor filled with zeros.
164+
* @throws IllegalArgumentException if the shape is null, dtype is null, or native allocation fails.
165+
*/
166+
public static Tensor zeros(@NonNull long[] shape, @NonNull ScalarType dtype) {
167+
if (shape == null) {
168+
throw new IllegalArgumentException("Shape cannot be null.");
169+
}
170+
if (dtype == null) {
171+
throw new IllegalArgumentException("Dtype cannot be null.");
172+
}
173+
long nativeHandle = nativeZeros(shape, dtype.getValue());
174+
if (nativeHandle == 0) {
175+
throw new IllegalArgumentException("Failed to create native Tensor with zeros.");
176+
}
177+
return new Tensor(nativeHandle);
178+
}
179+
180+
/**
181+
* Releases the native resources associated with this tensor.
182+
* After this method is called, the Tensor object becomes invalid.
183+
* This method is automatically called when the Tensor is used in a try-with-resources statement.
184+
*/
185+
@Override
186+
public void close() {
187+
if (nativeHandle != 0) {
188+
nativeRelease(nativeHandle);
189+
nativeHandle = 0;
190+
}
191+
}
192+
193+
// Static initializer to load the JNI library.
194+
// In a typical Android application, the JNI library might be loaded
195+
// once in a higher-level entry point (e.g., Application class or a Module class).
196+
// This block ensures the library is loaded if not already.
197+
static {
198+
try {
199+
System.loadLibrary("executorch_android_jni");
200+
} catch (UnsatisfiedLinkError e) {
201+
System.err.println("Failed to load native library 'executorch_android_jni': " + e.getMessage());
202+
// For a core library class like Tensor, if native functionality is essential,
203+
// it's appropriate to rethrow the error to indicate a critical setup failure.
204+
throw e;
205+
}
206+
}
207+
}

0 commit comments

Comments
 (0)