Skip to content

Commit

Permalink
Merge pull request #97 from pickleburger/primitive_array
Browse files Browse the repository at this point in the history
support primitive arrays in to_rust
  • Loading branch information
astonbitecode authored Apr 30, 2024
2 parents d287785 + a9c05cc commit 612a762
Show file tree
Hide file tree
Showing 4 changed files with 575 additions and 4 deletions.
40 changes: 40 additions & 0 deletions rust/src/api/invocation_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ impl InvocationArg {
class_name: class_name.to_string(),
serialized: false,
})
} else if let Some(a) = arg_any.downcast_ref::<u16>() {
Ok(InvocationArg::RustBasic {
instance: Instance::new(
jni_utils::global_jobject_from_u16(a, jni_env)?,
class_name,
)?,
class_name: class_name.to_string(),
serialized: false,
})
} else if let Some(a) = arg_any.downcast_ref::<i32>() {
Ok(InvocationArg::RustBasic {
instance: Instance::new(
Expand Down Expand Up @@ -446,6 +455,30 @@ impl<'a> TryFrom<&'a [i16]> for InvocationArg {
}
}

impl TryFrom<u16> for InvocationArg {
type Error = errors::J4RsError;
fn try_from(arg: u16) -> errors::Result<InvocationArg> {
InvocationArg::new_2(
&arg,
JavaClass::Character.into(),
cache::get_thread_local_env()?,
)
}
}

impl<'a> TryFrom<&'a [u16]> for InvocationArg {
type Error = errors::J4RsError;
fn try_from(vec: &'a [u16]) -> errors::Result<InvocationArg> {
let args: errors::Result<Vec<InvocationArg>> = vec
.iter()
.map(|elem| InvocationArg::try_from(elem))
.collect();
let res =
Jvm::do_create_java_list(cache::get_thread_local_env()?, cache::J4RS_ARRAY, &args?);
Ok(InvocationArg::from(res?))
}
}

impl TryFrom<i32> for InvocationArg {
type Error = errors::J4RsError;
fn try_from(arg: i32) -> errors::Result<InvocationArg> {
Expand Down Expand Up @@ -592,6 +625,13 @@ impl<'a> TryFrom<&'a i16> for InvocationArg {
}
}

impl<'a> TryFrom<&'a u16> for InvocationArg {
type Error = errors::J4RsError;
fn try_from(arg: &'a u16) -> errors::Result<InvocationArg> {
InvocationArg::new_2(arg, JavaClass::Character.into(), cache::get_thread_local_env()?)
}
}

impl<'a, 'b> TryFrom<&'a i32> for InvocationArg {
type Error = errors::J4RsError;
fn try_from(arg: &'a i32) -> errors::Result<InvocationArg> {
Expand Down
245 changes: 245 additions & 0 deletions rust/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ pub(crate) const PRIMITIVE_LONG: &'static str = "long";
pub(crate) const PRIMITIVE_FLOAT: &'static str = "float";
pub(crate) const PRIMITIVE_DOUBLE: &'static str = "double";
pub(crate) const PRIMITIVE_CHAR: &'static str = "char";

pub(crate) const PRIMITIVE_BOOLEAN_ARRAY: &'static str = "[Z";
pub(crate) const PRIMITIVE_BYTE_ARRAY: &'static str = "[B";
pub(crate) const PRIMITIVE_SHORT_ARRAY: &'static str = "[S";
pub(crate) const PRIMITIVE_INT_ARRAY: &'static str = "[I";
pub(crate) const PRIMITIVE_LONG_ARRAY: &'static str = "[J";
pub(crate) const PRIMITIVE_FLOAT_ARRAY: &'static str = "[F";
pub(crate) const PRIMITIVE_DOUBLE_ARRAY: &'static str = "[D";
pub(crate) const PRIMITIVE_CHAR_ARRAY: &'static str = "[C";

pub(crate) const CLASS_NATIVE_CALLBACK_TO_RUST_CHANNEL_SUPPORT: &'static str =
"org.astonbitecode.j4rs.api.invocation.NativeCallbackToRustChannelSupport";
pub(crate) const CLASS_J4RS_EVENT_HANDLER: &'static str =
Expand Down Expand Up @@ -249,6 +259,9 @@ impl Jvm {
let _ = cache::get_jni_call_short_method().or_else(|| {
cache::set_jni_call_short_method(Some((**jni_environment).v1_6.CallShortMethod))
});
let _ = cache::get_jni_call_char_method().or_else(|| {
cache::set_jni_call_char_method(Some((**jni_environment).v1_6.CallCharMethod))
});
let _ = cache::get_jni_call_int_method().or_else(|| {
cache::set_jni_call_int_method(Some((**jni_environment).v1_6.CallIntMethod))
});
Expand All @@ -269,6 +282,91 @@ impl Jvm {
(**jni_environment).v1_6.CallStaticObjectMethod,
))
});
let _ = cache::get_jni_get_array_length().or_else(|| {
cache::set_jni_get_array_length(Some(
(**jni_environment).v1_6.GetArrayLength,
))
});
let _ = cache::get_jni_get_byte_array_elements().or_else(|| {
cache::set_jni_get_byte_array_elements(Some(
(**jni_environment).v1_6.GetByteArrayElements,
))
});
let _ = cache::get_jni_release_byte_array_elements().or_else(|| {
cache::set_jni_release_byte_array_elements(Some(
(**jni_environment).v1_6.ReleaseByteArrayElements,
))
});
let _ = cache::get_jni_get_short_array_elements().or_else(|| {
cache::set_jni_get_short_array_elements(Some(
(**jni_environment).v1_6.GetShortArrayElements,
))
});
let _ = cache::get_jni_release_short_array_elements().or_else(|| {
cache::set_jni_release_short_array_elements(Some(
(**jni_environment).v1_6.ReleaseShortArrayElements,
))
});
let _ = cache::get_jni_get_char_array_elements().or_else(|| {
cache::set_jni_get_char_array_elements(Some(
(**jni_environment).v1_6.GetCharArrayElements,
))
});
let _ = cache::get_jni_release_char_array_elements().or_else(|| {
cache::set_jni_release_char_array_elements(Some(
(**jni_environment).v1_6.ReleaseCharArrayElements,
))
});
let _ = cache::get_jni_get_int_array_elements().or_else(|| {
cache::set_jni_get_int_array_elements(Some(
(**jni_environment).v1_6.GetIntArrayElements,
))
});
let _ = cache::get_jni_release_int_array_elements().or_else(|| {
cache::set_jni_release_int_array_elements(Some(
(**jni_environment).v1_6.ReleaseIntArrayElements,
))
});
let _ = cache::get_jni_get_long_array_elements().or_else(|| {
cache::set_jni_get_long_array_elements(Some(
(**jni_environment).v1_6.GetLongArrayElements,
))
});
let _ = cache::get_jni_release_long_array_elements().or_else(|| {
cache::set_jni_release_long_array_elements(Some(
(**jni_environment).v1_6.ReleaseLongArrayElements,
))
});
let _ = cache::get_jni_get_float_array_elements().or_else(|| {
cache::set_jni_get_float_array_elements(Some(
(**jni_environment).v1_6.GetFloatArrayElements,
))
});
let _ = cache::get_jni_release_float_array_elements().or_else(|| {
cache::set_jni_release_float_array_elements(Some(
(**jni_environment).v1_6.ReleaseFloatArrayElements,
))
});
let _ = cache::get_jni_get_double_array_elements().or_else(|| {
cache::set_jni_get_double_array_elements(Some(
(**jni_environment).v1_6.GetDoubleArrayElements,
))
});
let _ = cache::get_jni_release_double_array_elements().or_else(|| {
cache::set_jni_release_double_array_elements(Some(
(**jni_environment).v1_6.ReleaseDoubleArrayElements,
))
});
let _ = cache::get_jni_get_boolean_array_elements().or_else(|| {
cache::set_jni_get_boolean_array_elements(Some(
(**jni_environment).v1_6.GetBooleanArrayElements,
))
});
let _ = cache::get_jni_release_boolean_array_elements().or_else(|| {
cache::set_jni_release_boolean_array_elements(Some(
(**jni_environment).v1_6.ReleaseBooleanArrayElements,
))
});
let _ = cache::get_jni_new_object_array().or_else(|| {
cache::set_jni_new_object_array(Some((**jni_environment).v1_6.NewObjectArray))
});
Expand Down Expand Up @@ -1167,6 +1265,10 @@ impl Jvm {
&& (JavaClass::Short.get_class_str() == class_name || PRIMITIVE_SHORT == class_name)
{
rust_box_from_java_object!(jni_utils::i16_from_jobject)
} else if t_type == TypeId::of::<u16>()
&& (JavaClass::Character.get_class_str() == class_name || PRIMITIVE_CHAR == class_name)
{
rust_box_from_java_object!(jni_utils::u16_from_jobject)
} else if t_type == TypeId::of::<i64>()
&& (JavaClass::Long.get_class_str() == class_name || PRIMITIVE_LONG == class_name)
{
Expand All @@ -1180,6 +1282,38 @@ impl Jvm {
|| PRIMITIVE_DOUBLE == class_name)
{
rust_box_from_java_object!(jni_utils::f64_from_jobject)
} else if t_type == TypeId::of::<Vec<i8>>()
&& PRIMITIVE_BYTE_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::i8_array_from_jobject)
} else if t_type == TypeId::of::<Vec<i16>>()
&& PRIMITIVE_SHORT_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::i16_array_from_jobject)
} else if t_type == TypeId::of::<Vec<u16>>()
&& PRIMITIVE_CHAR_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::u16_array_from_jobject)
} else if t_type == TypeId::of::<Vec<i32>>()
&& PRIMITIVE_INT_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::i32_array_from_jobject)
} else if t_type == TypeId::of::<Vec<i64>>()
&& PRIMITIVE_LONG_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::i64_array_from_jobject)
} else if t_type == TypeId::of::<Vec<f32>>()
&& PRIMITIVE_FLOAT_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::f32_array_from_jobject)
} else if t_type == TypeId::of::<Vec<f64>>()
&& PRIMITIVE_DOUBLE_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::f64_array_from_jobject)
} else if t_type == TypeId::of::<Vec<bool>>()
&& PRIMITIVE_BOOLEAN_ARRAY == class_name
{
rust_box_from_java_object!(jni_utils::boolean_array_from_jobject)
} else {
Ok(Box::new(self.to_rust_deserialized(instance)?))
}
Expand Down Expand Up @@ -1999,6 +2133,102 @@ mod api_unit_tests {
Ok(())
}

#[test]
fn test_byte_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<i8> = vec![-3_i8, 7_i8, 8_i8];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_BYTE, &ia)?;
let rust_value_from_java: Vec<i8> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_short_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<i16> = vec![-3_i16, 7_i16, 10000_i16];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_SHORT, &ia)?;
let rust_value_from_java: Vec<i16> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_char_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<u16> = vec![3_u16, 7_u16, 10000_u16];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_CHAR, &ia)?;
let rust_value_from_java: Vec<u16> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_int_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<i32> = vec![-100_000, -1_000_000, 1_000_000];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_INT, &ia)?;
let rust_value_from_java: Vec<i32> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_long_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<i64> = vec![-100_000, -1_000_000, 1_000_000];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_LONG, &ia)?;
let rust_value_from_java: Vec<i64> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_float_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<f32> = vec![3_f32, 7.5_f32, -1000.5_f32];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_FLOAT, &ia)?;
let rust_value_from_java: Vec<f32> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_double_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<f64> = vec![3_f64, 7.5_f64, -1000.5_f64];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_DOUBLE, &ia)?;
let rust_value_from_java: Vec<f64> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_boolean_array_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: Vec<bool> = vec![false, true, false];
let ia: Vec<_> = rust_value.iter().map(|x| InvocationArg::try_from(x).unwrap().into_primitive().unwrap()).collect();
let java_instance = jvm.create_java_array(PRIMITIVE_BOOLEAN, &ia)?;
let rust_value_from_java: Vec<bool> = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_int_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
Expand Down Expand Up @@ -2044,6 +2274,21 @@ mod api_unit_tests {
Ok(())
}

#[test]
fn test_char_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
let rust_value: u16 = 3;
let ia = InvocationArg::try_from(rust_value)?.into_primitive()?;
let java_instance = jvm.create_instance(CLASS_CHARACTER, &[ia])?;
let java_primitive_instance = jvm.invoke(&java_instance, "charValue", InvocationArg::empty())?;
let rust_value_from_java: u16 = jvm.to_rust(java_instance)?;
assert_eq!(rust_value_from_java, rust_value);
let rust_value_from_java: u16 = jvm.to_rust(java_primitive_instance)?;
assert_eq!(rust_value_from_java, rust_value);

Ok(())
}

#[test]
fn test_long_to_rust() -> errors::Result<()> {
let jvm = create_tests_jvm()?;
Expand Down
Loading

0 comments on commit 612a762

Please sign in to comment.