1+ #include  < jni.h> 
2+ #include  < vector> 
3+ #include  < numeric> 
4+ #include  < stdexcept> 
5+ #include  < functional>   //  For std::function
6+ 
7+ //  Executorch headers
8+ #include  < executorch/runtime/core/exec_aten/exec_aten.h> 
9+ #include  < executorch/runtime/core/exec_aten/util/scalar_type_util.h> 
10+ #include  < executorch/runtime/executor/tensor_wrapper.h>   //  For ManagedTensor
11+ #include  < executorch/runtime/platform/runtime.h>   //  For ET_LOG
12+ 
13+ //  Assumed new tensor factory header
14+ //  This file is expected to define make_ones_tensor_ptr and make_zeros_tensor_ptr.
15+ //  These functions are assumed to return a `executorch::runtime::ManagedTensor*`
16+ //  (a raw pointer to a dynamically allocated ManagedTensor) which the Java side
17+ //  will take ownership of and be responsible for deallocating via a corresponding
18+ //  native deallocation method (e.g., nativeDeleteTensor).
19+ #include  < executorch/extension/tensor/tensor_factory.h> 
20+ 
21+ namespace  executorch  {
22+ namespace  android  {
23+ namespace  jni  {
24+ 
25+ //  JniTensorUtil provides utility functions for converting JNI types and
26+ //  handling common JNI error patterns in a robust way.
27+ class  JniTensorUtil  {
28+ public: 
29+     //  Converts a jlongArray (Java long[]) to a C++ exec_aten::Tensor::SizesType.
30+     //  If an error occurs (e.g., OOM, negative dimension), it throws a Java exception
31+     //  and returns an empty vector. The caller must check env->ExceptionCheck().
32+     static  exec_aten::Tensor::SizesType jlongArrayToSizesType (JNIEnv* env, jlongArray jshape) {
33+         if  (jshape == nullptr ) {
34+             //  An empty shape typically represents a scalar tensor, which is valid.
35+             return  {};
36+         }
37+ 
38+         jsize length = env->GetArrayLength (jshape);
39+         exec_aten::Tensor::SizesType sizes (length);
40+         jlong* elements = env->GetLongArrayElements (jshape, nullptr );
41+ 
42+         if  (elements == nullptr ) {
43+             //  GetLongArrayElements can return nullptr on OOM.
44+             env->ThrowNew (env->FindClass (" java/lang/OutOfMemoryError"  ), " Failed to get jlongArray elements for shape."  );
45+             return  {};
46+         }
47+ 
48+         for  (int  i = 0 ; i < length; ++i) {
49+             if  (elements[i] < 0 ) {
50+                 //  Release elements before throwing to prevent resource leak.
51+                 env->ReleaseLongArrayElements (jshape, elements, JNI_ABORT);
52+                 env->ThrowNew (env->FindClass (" java/lang/IllegalArgumentException"  ), " Tensor dimensions must be non-negative."  );
53+                 return  {};
54+             }
55+             sizes[i] = static_cast <exec_aten::Tensor::SizesType::value_type>(elements[i]);
56+         }
57+         //  Release elements without copying back changes (JNI_ABORT) as we only read.
58+         env->ReleaseLongArrayElements (jshape, elements, JNI_ABORT);
59+         return  sizes;
60+     }
61+ 
62+     //  Converts a jint (Java ScalarType ordinal) to a C++ exec_aten::ScalarType.
63+     //  This assumes a direct ordinal mapping between Java and C++ enums.
64+     //  Throws a Java IllegalArgumentException if the ordinal is out of range.
65+     static  exec_aten::ScalarType jintToScalarType (JNIEnv* env, jint jscalarType) {
66+         if  (jscalarType < 0  || jscalarType >= static_cast <jint>(exec_aten::ScalarType::NumOptions)) {
67+             ET_LOG (
68+                 ERROR,
69+                 " Invalid ScalarType ordinal value received: %d. Expected range [0, %d)."  ,
70+                 jscalarType,
71+                 static_cast <jint>(exec_aten::ScalarType::NumOptions));
72+             env->ThrowNew (env->FindClass (" java/lang/IllegalArgumentException"  ), " Invalid ScalarType ordinal value."  );
73+             return  exec_aten::ScalarType::Undefined; //  Sentinel for error
74+         }
75+         return  static_cast <exec_aten::ScalarType>(jscalarType);
76+     }
77+ 
78+     //  A generic helper function to create a native tensor using a provided C++ factory.
79+     //  It encapsulates JNI type conversions, robust C++ exception handling, and
80+     //  propagates errors as Java exceptions.
81+     static  jlong createTensorFromFactory (
82+         JNIEnv* env,
83+         jlongArray jshape,
84+         jint jscalarType,
85+         //  The factory function takes sizes and dtype, and returns a ManagedTensor*.
86+         std::function<executorch::runtime::ManagedTensor*(const  exec_aten::Tensor::SizesType&, exec_aten::ScalarType)> factory_func) {
87+         try  {
88+             //  1. Convert Java shape array to C++ vector.
89+             exec_aten::Tensor::SizesType sizes = JniTensorUtil::jlongArrayToSizesType (env, jshape);
90+             if  (env->ExceptionCheck ()) {
91+                 //  An exception was thrown by jlongArrayToSizesType; return immediately.
92+                 return  0 ;
93+             }
94+ 
95+             //  2. Convert Java scalar type ordinal to C++ enum.
96+             exec_aten::ScalarType dtype = JniTensorUtil::jintToScalarType (env, jscalarType);
97+             if  (env->ExceptionCheck ()) {
98+                 //  An exception was thrown by jintToScalarType; return immediately.
99+                 return  0 ;
100+             }
101+             //  Additional safeguard: if dtype is Undefined due to an internal logic error
102+             //  not caught by the explicit range check in jintToScalarType.
103+             if  (dtype == exec_aten::ScalarType::Undefined) {
104+                 env->ThrowNew (env->FindClass (" java/lang/IllegalArgumentException"  ), " Attempted to create tensor with an Undefined ScalarType."  );
105+                 return  0 ;
106+             }
107+ 
108+             //  3. Call the C++ tensor factory function.
109+             executorch::runtime::ManagedTensor* managed_tensor_ptr = factory_func (sizes, dtype);
110+ 
111+             //  4. Return the raw pointer to Java.
112+             //  The Java side is responsible for managing the lifecycle of this native
113+             //  resource (i.e., calling a corresponding nativeDelete method when done).
114+             return  reinterpret_cast <jlong>(managed_tensor_ptr);
115+ 
116+         } catch  (const  std::bad_alloc& e) {
117+             ET_LOG (ERROR, " JNI: Out of memory during tensor creation: %s"  , e.what ());
118+             env->ThrowNew (env->FindClass (" java/lang/OutOfMemoryError"  ), e.what ());
119+             return  0 ;
120+         } catch  (const  std::exception& e) {
121+             //  Catching standard C++ exceptions (e.g., from factory function if it allocates and fails).
122+             ET_LOG (ERROR, " JNI: C++ exception during tensor creation: %s"  , e.what ());
123+             env->ThrowNew (env->FindClass (" java/lang/RuntimeException"  ), e.what ());
124+             return  0 ;
125+         } catch  (...) {
126+             //  Catching any other unknown C++ exceptions.
127+             ET_LOG (ERROR, " JNI: Unknown C++ exception during tensor creation."  );
128+             env->ThrowNew (env->FindClass (" java/lang/RuntimeException"  ), " Unknown C++ exception during tensor creation."  );
129+             return  0 ;
130+         }
131+     }
132+ };
133+ 
134+ } //  namespace jni
135+ } //  namespace android
136+ } //  namespace executorch
137+ 
138+ //  JNI function to create a new tensor filled with ones.
139+ //  It assumes the Java class path is `org.pytorch.executorch.Tensor`.
140+ //  The 'clazz' parameter is required by JNI signature but is unused in this implementation.
141+ extern  " C"   JNIEXPORT jlong JNICALL Java_org_pytorch_executorch_Tensor_nativeNewOnesTensor (
142+     JNIEnv* env,
143+     jclass /*  clazz */  , //  Marked as unused to suppress warnings
144+     jlongArray jshape,
145+     jint jscalarType) {
146+     return  executorch::android::jni::JniTensorUtil::createTensorFromFactory (
147+         env,
148+         jshape,
149+         jscalarType,
150+         [](const  executorch::exec_aten::Tensor::SizesType& sizes, executorch::exec_aten::ScalarType dtype) {
151+             return  executorch::extension::tensor::make_ones_tensor_ptr (sizes, dtype);
152+         });
153+ }
154+ 
155+ //  JNI function to create a new tensor filled with zeros.
156+ //  It assumes the Java class path is `org.pytorch.executorch.Tensor`.
157+ //  The 'clazz' parameter is required by JNI signature but is unused in this implementation.
158+ extern  " C"   JNIEXPORT jlong JNICALL Java_org_pytorch_executorch_Tensor_nativeNewZerosTensor (
159+     JNIEnv* env,
160+     jclass /*  clazz */  , //  Marked as unused to suppress warnings
161+     jlongArray jshape,
162+     jint jscalarType) {
163+     return  executorch::android::jni::JniTensorUtil::createTensorFromFactory (
164+         env,
165+         jshape,
166+         jscalarType,
167+         [](const  executorch::exec_aten::Tensor::SizesType& sizes, executorch::exec_aten::ScalarType dtype) {
168+             return  executorch::extension::tensor::make_zeros_tensor_ptr (sizes, dtype);
169+         });
170+ }
0 commit comments