Skip to content

Commit 7ca14ec

Browse files
committed
fix: Resolve issue #15125
1 parent 1a5eaec commit 7ca14ec

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#include <jni.h>
2+
#include <vector>
3+
#include <numeric>
4+
#include <stdexcept>
5+
#include <functional> // For std::function
6+
7+
// Executorch headers
8+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
9+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
10+
#include <executorch/runtime/executor/tensor_wrapper.h> // For ManagedTensor
11+
#include <executorch/runtime/platform/runtime.h> // For ET_LOG
12+
13+
// Assumed new tensor factory header
14+
// This file is expected to define make_ones_tensor_ptr and make_zeros_tensor_ptr.
15+
// These functions are assumed to return a `executorch::runtime::ManagedTensor*`
16+
// (a raw pointer to a dynamically allocated ManagedTensor) which the Java side
17+
// will take ownership of and be responsible for deallocating via a corresponding
18+
// native deallocation method (e.g., nativeDeleteTensor).
19+
#include <executorch/extension/tensor/tensor_factory.h>
20+
21+
namespace executorch {
22+
namespace android {
23+
namespace jni {
24+
25+
// JniTensorUtil provides utility functions for converting JNI types and
26+
// handling common JNI error patterns in a robust way.
27+
class JniTensorUtil {
28+
public:
29+
// Converts a jlongArray (Java long[]) to a C++ exec_aten::Tensor::SizesType.
30+
// If an error occurs (e.g., OOM, negative dimension), it throws a Java exception
31+
// and returns an empty vector. The caller must check env->ExceptionCheck().
32+
static exec_aten::Tensor::SizesType jlongArrayToSizesType(JNIEnv* env, jlongArray jshape) {
33+
if (jshape == nullptr) {
34+
// An empty shape typically represents a scalar tensor, which is valid.
35+
return {};
36+
}
37+
38+
jsize length = env->GetArrayLength(jshape);
39+
exec_aten::Tensor::SizesType sizes(length);
40+
jlong* elements = env->GetLongArrayElements(jshape, nullptr);
41+
42+
if (elements == nullptr) {
43+
// GetLongArrayElements can return nullptr on OOM.
44+
env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), "Failed to get jlongArray elements for shape.");
45+
return {};
46+
}
47+
48+
for (int i = 0; i < length; ++i) {
49+
if (elements[i] < 0) {
50+
// Release elements before throwing to prevent resource leak.
51+
env->ReleaseLongArrayElements(jshape, elements, JNI_ABORT);
52+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Tensor dimensions must be non-negative.");
53+
return {};
54+
}
55+
sizes[i] = static_cast<exec_aten::Tensor::SizesType::value_type>(elements[i]);
56+
}
57+
// Release elements without copying back changes (JNI_ABORT) as we only read.
58+
env->ReleaseLongArrayElements(jshape, elements, JNI_ABORT);
59+
return sizes;
60+
}
61+
62+
// Converts a jint (Java ScalarType ordinal) to a C++ exec_aten::ScalarType.
63+
// This assumes a direct ordinal mapping between Java and C++ enums.
64+
// Throws a Java IllegalArgumentException if the ordinal is out of range.
65+
static exec_aten::ScalarType jintToScalarType(JNIEnv* env, jint jscalarType) {
66+
if (jscalarType < 0 || jscalarType >= static_cast<jint>(exec_aten::ScalarType::NumOptions)) {
67+
ET_LOG(
68+
ERROR,
69+
"Invalid ScalarType ordinal value received: %d. Expected range [0, %d).",
70+
jscalarType,
71+
static_cast<jint>(exec_aten::ScalarType::NumOptions));
72+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Invalid ScalarType ordinal value.");
73+
return exec_aten::ScalarType::Undefined; // Sentinel for error
74+
}
75+
return static_cast<exec_aten::ScalarType>(jscalarType);
76+
}
77+
78+
// A generic helper function to create a native tensor using a provided C++ factory.
79+
// It encapsulates JNI type conversions, robust C++ exception handling, and
80+
// propagates errors as Java exceptions.
81+
static jlong createTensorFromFactory(
82+
JNIEnv* env,
83+
jlongArray jshape,
84+
jint jscalarType,
85+
// The factory function takes sizes and dtype, and returns a ManagedTensor*.
86+
std::function<executorch::runtime::ManagedTensor*(const exec_aten::Tensor::SizesType&, exec_aten::ScalarType)> factory_func) {
87+
try {
88+
// 1. Convert Java shape array to C++ vector.
89+
exec_aten::Tensor::SizesType sizes = JniTensorUtil::jlongArrayToSizesType(env, jshape);
90+
if (env->ExceptionCheck()) {
91+
// An exception was thrown by jlongArrayToSizesType; return immediately.
92+
return 0;
93+
}
94+
95+
// 2. Convert Java scalar type ordinal to C++ enum.
96+
exec_aten::ScalarType dtype = JniTensorUtil::jintToScalarType(env, jscalarType);
97+
if (env->ExceptionCheck()) {
98+
// An exception was thrown by jintToScalarType; return immediately.
99+
return 0;
100+
}
101+
// Additional safeguard: if dtype is Undefined due to an internal logic error
102+
// not caught by the explicit range check in jintToScalarType.
103+
if (dtype == exec_aten::ScalarType::Undefined) {
104+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Attempted to create tensor with an Undefined ScalarType.");
105+
return 0;
106+
}
107+
108+
// 3. Call the C++ tensor factory function.
109+
executorch::runtime::ManagedTensor* managed_tensor_ptr = factory_func(sizes, dtype);
110+
111+
// 4. Return the raw pointer to Java.
112+
// The Java side is responsible for managing the lifecycle of this native
113+
// resource (i.e., calling a corresponding nativeDelete method when done).
114+
return reinterpret_cast<jlong>(managed_tensor_ptr);
115+
116+
} catch (const std::bad_alloc& e) {
117+
ET_LOG(ERROR, "JNI: Out of memory during tensor creation: %s", e.what());
118+
env->ThrowNew(env->FindClass("java/lang/OutOfMemoryError"), e.what());
119+
return 0;
120+
} catch (const std::exception& e) {
121+
// Catching standard C++ exceptions (e.g., from factory function if it allocates and fails).
122+
ET_LOG(ERROR, "JNI: C++ exception during tensor creation: %s", e.what());
123+
env->ThrowNew(env->FindClass("java/lang/RuntimeException"), e.what());
124+
return 0;
125+
} catch (...) {
126+
// Catching any other unknown C++ exceptions.
127+
ET_LOG(ERROR, "JNI: Unknown C++ exception during tensor creation.");
128+
env->ThrowNew(env->FindClass("java/lang/RuntimeException"), "Unknown C++ exception during tensor creation.");
129+
return 0;
130+
}
131+
}
132+
};
133+
134+
} // namespace jni
135+
} // namespace android
136+
} // namespace executorch
137+
138+
// JNI function to create a new tensor filled with ones.
139+
// It assumes the Java class path is `org.pytorch.executorch.Tensor`.
140+
// The 'clazz' parameter is required by JNI signature but is unused in this implementation.
141+
extern "C" JNIEXPORT jlong JNICALL Java_org_pytorch_executorch_Tensor_nativeNewOnesTensor(
142+
JNIEnv* env,
143+
jclass /* clazz */, // Marked as unused to suppress warnings
144+
jlongArray jshape,
145+
jint jscalarType) {
146+
return executorch::android::jni::JniTensorUtil::createTensorFromFactory(
147+
env,
148+
jshape,
149+
jscalarType,
150+
[](const executorch::exec_aten::Tensor::SizesType& sizes, executorch::exec_aten::ScalarType dtype) {
151+
return executorch::extension::tensor::make_ones_tensor_ptr(sizes, dtype);
152+
});
153+
}
154+
155+
// JNI function to create a new tensor filled with zeros.
156+
// It assumes the Java class path is `org.pytorch.executorch.Tensor`.
157+
// The 'clazz' parameter is required by JNI signature but is unused in this implementation.
158+
extern "C" JNIEXPORT jlong JNICALL Java_org_pytorch_executorch_Tensor_nativeNewZerosTensor(
159+
JNIEnv* env,
160+
jclass /* clazz */, // Marked as unused to suppress warnings
161+
jlongArray jshape,
162+
jint jscalarType) {
163+
return executorch::android::jni::JniTensorUtil::createTensorFromFactory(
164+
env,
165+
jshape,
166+
jscalarType,
167+
[](const executorch::exec_aten::Tensor::SizesType& sizes, executorch::exec_aten::ScalarType dtype) {
168+
return executorch::extension::tensor::make_zeros_tensor_ptr(sizes, dtype);
169+
});
170+
}

0 commit comments

Comments
 (0)