From 4723385b4cd29e977f5fbedee93d5ac174fe0458 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Tue, 3 Dec 2024 14:08:45 +0100 Subject: [PATCH 1/7] Start implementing generic wrapper types --- rustler/src/types/mod.rs | 4 +++ rustler/src/types/reference.rs | 28 ++++++++++--------- rustler/src/types/wrapper.rs | 50 ++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 rustler/src/types/wrapper.rs diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index c7b72005..01947b79 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -1,5 +1,9 @@ use crate::{Env, Error, NifResult, Term}; +#[macro_use] +mod wrapper; +pub(crate) use self::wrapper::wrapper; + #[macro_use] pub mod atom; pub use crate::types::atom::Atom; diff --git a/rustler/src/types/reference.rs b/rustler/src/types/reference.rs index 9514f854..ce48260e 100644 --- a/rustler/src/types/reference.rs +++ b/rustler/src/types/reference.rs @@ -1,22 +1,24 @@ use std::ops::Deref; -use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; use crate::sys::enif_make_ref; -/// Wrapper for BEAM reference terms. -#[derive(PartialEq, Eq, Clone, Copy)] -pub struct Reference<'a>(Term<'a>); +wrapper!(Reference, TermType::Ref); -impl Reference<'_> { - /// Returns a representation of self in the given Env. - /// - /// If the term is already is in the provided env, it will be directly returned. Otherwise - /// the term will be copied over. - pub fn in_env<'b>(&self, env: Env<'b>) -> Reference<'b> { - Reference(self.0.in_env(env)) - } -} +/// Wrapper for BEAM reference terms. +// #[derive(PartialEq, Eq, Clone, Copy)] +// pub struct Reference<'a>(Term<'a>); +// +// impl<'a> Reference<'a> { +// /// Returns a representation of self in the given Env. +// /// +// /// If the term is already is in the provided env, it will be directly returned. Otherwise +// /// the term will be copied over. +// pub fn in_env<'b>(&self, env: Env<'b>) -> Reference<'b> { +// Reference(self.0.in_env(env)) +// } +// } impl<'a> Deref for Reference<'a> { type Target = Term<'a>; diff --git a/rustler/src/types/wrapper.rs b/rustler/src/types/wrapper.rs new file mode 100644 index 00000000..ff843563 --- /dev/null +++ b/rustler/src/types/wrapper.rs @@ -0,0 +1,50 @@ +use crate::TermType; + +struct WrapperError; + +pub trait Wrapper<'a>: Sized { + const WRAPPED_TYPE: TermType; + + fn wrap<'b>(term: Term<'a>) -> Result { + if term.get_type() == Self::WRAPPED_TYPE { + unsafe { Ok(Self::wrap_unsafe(term)) } + } else { + Err(WrapperError) + } + } + + fn unwrap(&self) -> Term<'a>; + + unsafe fn wrap_unsafe(term: Term<'a>) -> Self; + + /// Returns a representation of self in the given Env. + /// + /// If the term is already is in the provided env, it will be directly returned. Otherwise + /// the term will be copied over. + fn in_env<'b>(&self, env: Env<'b>) -> impl Wrapper<'b, WRAPPED_TYPE = Self::WrappedType> { + self.unwrap().in_env(env) + } +} + +macro_rules! wrapper { + ($name:ident, $term_type:path) => { + #[derive(PartialEq, Eq, Clone, Copy)] + pub struct $name<'a>(Term<'a>); + + impl<'a> $crate::types::wrapper::Wrapper<'a> for $name<'a> { + const WrappedType: TermType = $term_type; + + unsafe fn wrap_unsafe(term: Term<'a>) -> Self { + $name(term) + } + + fn unwrap(&self) -> Term<'a> { + self.0 + } + } + }; +} + +pub(crate) use wrapper; + +use crate::{Env, Term}; From e5e142fcd20979f309c917498b550b23f27c5a1e Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Wed, 18 Dec 2024 09:06:15 +0100 Subject: [PATCH 2/7] Current state --- rustler/src/lib.rs | 6 +- rustler/src/types/elixir_struct.rs | 15 +- rustler/src/types/list.rs | 204 --------------- rustler/src/types/mod.rs | 1 - rustler/src/types/reference.rs | 61 +---- rustler/src/types/wrapper.rs | 50 ---- rustler/src/wrapped_types/list.rs | 247 ++++++++++++++++++ rustler/src/{types => wrapped_types}/map.rs | 62 ++++- rustler/src/wrapped_types/mod.rs | 7 + rustler/src/{types => wrapped_types}/tuple.rs | 66 +++-- rustler/src/wrapped_types/wrapper.rs | 106 ++++++++ rustler/src/wrapper/list.rs | 42 +-- rustler/src/wrapper/tuple.rs | 13 +- .../native/rustler_test/src/test_env.rs | 4 +- 14 files changed, 491 insertions(+), 393 deletions(-) delete mode 100644 rustler/src/types/list.rs delete mode 100644 rustler/src/types/wrapper.rs create mode 100644 rustler/src/wrapped_types/list.rs rename rustler/src/{types => wrapped_types}/map.rs (86%) create mode 100644 rustler/src/wrapped_types/mod.rs rename rustler/src/{types => wrapped_types}/tuple.rs (64%) create mode 100644 rustler/src/wrapped_types/wrapper.rs diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index 159e7dc7..fc3e7047 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -35,10 +35,14 @@ mod alloc; pub mod types; mod term; +mod wrapped_types; +pub use crate::wrapped_types::{ + ListIterator, Map +}; pub use crate::term::Term; pub use crate::types::{ - Atom, Binary, Decoder, Encoder, ErlOption, ListIterator, LocalPid, MapIterator, NewBinary, + Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, MapIterator, NewBinary, OwnedBinary, Reference, }; diff --git a/rustler/src/types/elixir_struct.rs b/rustler/src/types/elixir_struct.rs index d0e60dc3..7bde54a0 100644 --- a/rustler/src/types/elixir_struct.rs +++ b/rustler/src/types/elixir_struct.rs @@ -8,19 +8,22 @@ //! `#[module = "Elixir.TheStructModule"]`. use super::atom::{self, Atom}; -use super::map::map_new; -use crate::{Env, NifResult, Term}; +use super::map::Map; +use crate::{Env, Error, NifResult, Term}; pub fn get_ex_struct_name(map: Term) -> NifResult { // In an Elixir struct the value in the __struct__ field is always an atom. - map.map_get(atom::__struct__()).and_then(Atom::from_term) + let map: Map<'_> = map.try_into()?; + map.get(atom::__struct__()) + .ok_or(Error::BadArg) + .and_then(Atom::from_term) } -pub fn make_ex_struct<'a>(env: Env<'a>, struct_module: &str) -> NifResult> { - let map = map_new(env); +pub fn make_ex_struct<'a>(env: Env<'a>, struct_module: &str) -> NifResult> { + let map = env.new_map(); let struct_atom = atom::__struct__(); let module_atom = Atom::from_str(env, struct_module)?; - map.map_put(struct_atom, module_atom) + map.put(struct_atom, module_atom) } diff --git a/rustler/src/types/list.rs b/rustler/src/types/list.rs deleted file mode 100644 index 22048d4a..00000000 --- a/rustler/src/types/list.rs +++ /dev/null @@ -1,204 +0,0 @@ -//! Utilities used for working with erlang linked lists. -//! -//! Right now the only supported way to read lists are through the ListIterator. - -use crate::wrapper::{list, NIF_TERM}; -use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; - -/// Enables iteration over the items in the list. -/// -/// Although this behaves like a standard Rust iterator -/// ([book](https://doc.rust-lang.org/book/iterators.html) / -/// [docs](https://doc.rust-lang.org/std/iter/trait.Iterator.html)), there are a couple of tricky -/// parts to using it. -/// -/// Because the iterator is an iterator over `Term`s, you need to decode the terms before you -/// can do anything with them. -/// -/// ## Example -/// An easy way to decode all terms in a list, is to use the `.map()` function of the iterator, and -/// decode every entry in the list. This will produce an iterator of `Result`s, and will therefore -/// not be directly usable in the way you might immediately expect. -/// -/// For this case, the the `.collect()` function of rust iterators is useful, as it can lift -/// the `Result`s out of the list. (Contains extra type annotations for clarity) -/// -/// ``` -/// # use rustler::{Term, NifResult}; -/// # use rustler::types::list::ListIterator; -/// # fn list_iterator_example(list_term: Term) -> NifResult> { -/// let list_iterator: ListIterator = list_term.decode()?; -/// -/// let result: NifResult> = list_iterator -/// // Produces an iterator of NifResult -/// .map(|x| x.decode::()) -/// // Lifts each value out of the result. Returns Ok(Vec) if successful, the first error -/// // Error(Error) on failure. -/// .collect::>>(); -/// # result -/// # } -/// ``` -pub struct ListIterator<'a> { - term: Term<'a>, -} - -impl<'a> ListIterator<'a> { - fn new(term: Term<'a>) -> Option { - if term.is_list() { - let iter = ListIterator { term }; - Some(iter) - } else { - None - } - } -} - -impl<'a> Iterator for ListIterator<'a> { - type Item = Term<'a>; - - fn next(&mut self) -> Option> { - let env = self.term.get_env(); - let cell = unsafe { list::get_list_cell(env.as_c_arg(), self.term.as_c_arg()) }; - - match cell { - Some((head, tail)) => unsafe { - self.term = Term::new(self.term.get_env(), tail); - Some(Term::new(self.term.get_env(), head)) - }, - None => { - if self.term.is_empty_list() { - // We reached the end of the list, finish the iterator. - None - } else { - panic!("list iterator found improper list") - } - } - } - } -} - -impl<'a> Decoder<'a> for ListIterator<'a> { - fn decode(term: Term<'a>) -> NifResult { - match ListIterator::new(term) { - Some(iter) => Ok(iter), - None => Err(Error::BadArg), - } - } -} - -//impl<'a, T> Encoder for Iterator where T: Encoder { -// fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { -// let term_arr: Vec = -// self.map(|x| x.encode(env).as_c_arg()).collect(); -// } -//} -impl Encoder for Vec -where - T: Encoder, -{ - fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { - self.as_slice().encode(env) - } -} - -impl<'a, T> Decoder<'a> for Vec -where - T: Decoder<'a>, -{ - fn decode(term: Term<'a>) -> NifResult { - let iter: ListIterator = term.decode()?; - let res: NifResult = iter.map(|x| x.decode::()).collect(); - res - } -} - -impl Encoder for [T] -where - T: Encoder, -{ - fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { - let term_array: Vec = self.iter().map(|x| x.encode(env).as_c_arg()).collect(); - unsafe { Term::new(env, list::make_list(env.as_c_arg(), &term_array)) } - } -} - -impl Encoder for &[T] -where - T: Encoder, -{ - fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { - let term_array: Vec = self.iter().map(|x| x.encode(env).as_c_arg()).collect(); - unsafe { Term::new(env, list::make_list(env.as_c_arg(), &term_array)) } - } -} - -/// ## List terms -impl<'a> Term<'a> { - /// Returns a new empty list. - pub fn list_new_empty(env: Env<'a>) -> Term<'a> { - let list: &[u8] = &[]; - list.encode(env) - } - - /// Returns an iterator over a list term. - /// See documentation for ListIterator for more information. - /// - /// Returns None if the term is not a list. - pub fn into_list_iterator(self) -> NifResult> { - ListIterator::new(self).ok_or(Error::BadArg) - } - - /// Returns the length of a list term. - /// - /// Returns None if the term is not a list. - /// - /// ### Elixir equivalent - /// ```elixir - /// length(self_term) - /// ``` - pub fn list_length(self) -> NifResult { - unsafe { list::get_list_length(self.get_env().as_c_arg(), self.as_c_arg()) } - .ok_or(Error::BadArg) - } - - /// Unpacks a single cell at the head of a list term, - /// and returns the result as a tuple of (head, tail). - /// - /// Returns None if the term is not a list. - /// - /// ### Elixir equivalent - /// ```elixir - /// [head, tail] = self_term - /// {head, tail} - /// ``` - pub fn list_get_cell(self) -> NifResult<(Term<'a>, Term<'a>)> { - let env = self.get_env(); - unsafe { - list::get_list_cell(env.as_c_arg(), self.as_c_arg()) - .map(|(t1, t2)| (Term::new(env, t1), Term::new(env, t2))) - .ok_or(Error::BadArg) - } - } - - /// Makes a copy of the self list term and reverses it. - /// - /// Returns Err(Error::BadArg) if the term is not a list. - pub fn list_reverse(self) -> NifResult> { - let env = self.get_env(); - unsafe { - list::make_reverse_list(env.as_c_arg(), self.as_c_arg()) - .map(|t| Term::new(env, t)) - .ok_or(Error::BadArg) - } - } - - /// Adds `head` in a list cell with `self` as tail. - pub fn list_prepend(self, head: impl Encoder) -> Term<'a> { - let env = self.get_env(); - let head = head.encode(env); - unsafe { - let term = list::make_list_cell(env.as_c_arg(), head.as_c_arg(), self.as_c_arg()); - Term::new(env, term) - } - } -} diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index 01947b79..829cded8 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -2,7 +2,6 @@ use crate::{Env, Error, NifResult, Term}; #[macro_use] mod wrapper; -pub(crate) use self::wrapper::wrapper; #[macro_use] pub mod atom; diff --git a/rustler/src/types/reference.rs b/rustler/src/types/reference.rs index ce48260e..4d4873e0 100644 --- a/rustler/src/types/reference.rs +++ b/rustler/src/types/reference.rs @@ -1,66 +1,15 @@ -use std::ops::Deref; - -use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; +use crate::{Env, Term, TermType}; use crate::sys::enif_make_ref; -wrapper!(Reference, TermType::Ref); - -/// Wrapper for BEAM reference terms. -// #[derive(PartialEq, Eq, Clone, Copy)] -// pub struct Reference<'a>(Term<'a>); -// -// impl<'a> Reference<'a> { -// /// Returns a representation of self in the given Env. -// /// -// /// If the term is already is in the provided env, it will be directly returned. Otherwise -// /// the term will be copied over. -// pub fn in_env<'b>(&self, env: Env<'b>) -> Reference<'b> { -// Reference(self.0.in_env(env)) -// } -// } - -impl<'a> Deref for Reference<'a> { - type Target = Term<'a>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<'a> From> for Term<'a> { - fn from(term: Reference<'a>) -> Self { - term.0 - } -} - -impl<'a> TryFrom> for Reference<'a> { - type Error = Error; - - fn try_from(term: Term<'a>) -> Result { - if term.is_ref() { - Ok(Reference(term)) - } else { - Err(Error::BadArg) - } - } -} - -impl<'a> Decoder<'a> for Reference<'a> { - fn decode(term: Term<'a>) -> NifResult { - term.try_into() - } -} - -impl Encoder for Reference<'_> { - fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { - self.0.encode(env) - } +wrapper!{ + struct Reference(TermType::Ref) } impl<'a> Env<'a> { /// Create a new reference in this environment pub fn make_ref(self) -> Reference<'a> { - unsafe { Reference(Term::new(self, enif_make_ref(self.as_c_arg()))) } + let term = unsafe { Term::new(self, enif_make_ref(self.as_c_arg())) }; + unsafe { Reference::wrap_unchecked(term) } } } diff --git a/rustler/src/types/wrapper.rs b/rustler/src/types/wrapper.rs deleted file mode 100644 index ff843563..00000000 --- a/rustler/src/types/wrapper.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::TermType; - -struct WrapperError; - -pub trait Wrapper<'a>: Sized { - const WRAPPED_TYPE: TermType; - - fn wrap<'b>(term: Term<'a>) -> Result { - if term.get_type() == Self::WRAPPED_TYPE { - unsafe { Ok(Self::wrap_unsafe(term)) } - } else { - Err(WrapperError) - } - } - - fn unwrap(&self) -> Term<'a>; - - unsafe fn wrap_unsafe(term: Term<'a>) -> Self; - - /// Returns a representation of self in the given Env. - /// - /// If the term is already is in the provided env, it will be directly returned. Otherwise - /// the term will be copied over. - fn in_env<'b>(&self, env: Env<'b>) -> impl Wrapper<'b, WRAPPED_TYPE = Self::WrappedType> { - self.unwrap().in_env(env) - } -} - -macro_rules! wrapper { - ($name:ident, $term_type:path) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub struct $name<'a>(Term<'a>); - - impl<'a> $crate::types::wrapper::Wrapper<'a> for $name<'a> { - const WrappedType: TermType = $term_type; - - unsafe fn wrap_unsafe(term: Term<'a>) -> Self { - $name(term) - } - - fn unwrap(&self) -> Term<'a> { - self.0 - } - } - }; -} - -pub(crate) use wrapper; - -use crate::{Env, Term}; diff --git a/rustler/src/wrapped_types/list.rs b/rustler/src/wrapped_types/list.rs new file mode 100644 index 00000000..d92e7611 --- /dev/null +++ b/rustler/src/wrapped_types/list.rs @@ -0,0 +1,247 @@ +use super::wrapper; + +use std::mem::MaybeUninit; + +use crate::sys::{ + enif_get_list_cell, enif_get_list_length, enif_make_list_cell, enif_make_reverse_list, + get_enif_make_list, +}; +use crate::wrapper::{list, NIF_TERM}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; + +wrapper!( + /// Enables iteration over the items in the list. + /// + /// Although this behaves like a standard Rust iterator + /// ([book](https://doc.rust-lang.org/book/iterators.html) / + /// [docs](https://doc.rust-lang.org/std/iter/trait.Iterator.html)), there are a couple of tricky + /// parts to using it. + /// + /// Because the iterator is an iterator over `Term`s, you need to decode the terms before you + /// can do anything with them. + /// + /// ## Example + /// An easy way to decode all terms in a list, is to use the `.map()` function of the iterator, and + /// decode every entry in the list. This will produce an iterator of `Result`s, and will therefore + /// not be directly usable in the way you might immediately expect. + /// + /// For this case, the the `.collect()` function of rust iterators is useful, as it can lift + /// the `Result`s out of the list. (Contains extra type annotations for clarity) + /// + /// ``` + /// # use rustler::{Term, NifResult}; + /// # use rustler::types::list::ListIterator; + /// # fn list_iterator_example(list_term: Term) -> NifResult> { + /// let list_iterator: ListIterator = list_term.decode()?; + /// + /// let result: NifResult> = list_iterator + /// // Produces an iterator of NifResult + /// .map(|x| x.decode::()) + /// // Lifts each value out of the result. Returns Ok(Vec) if successful, the first error + /// // Error(Error) on failure. + /// .collect::>>(); + /// # result + /// # } + /// ``` + struct ListIterator(TermType::List) +); + +impl<'a> Iterator for ListIterator<'a> { + type Item = Term<'a>; + + fn next(&mut self) -> Option> { + match self.get_cell() { + Some((head, tail)) => { + // TODO: This is unsafe as tail might not be a list. + self.0 = tail; + Some(head) + }, + None => { + if self.is_empty_list() { + // We reached the end of the list, finish the iterator. + None + } else { + panic!("list iterator found improper list") + } + } + } + } +} + +//impl<'a, T> Encoder for Iterator where T: Encoder { +// fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { +// let term_arr: Vec = +// self.map(|x| x.encode(env).as_c_arg()).collect(); +// } +//} +impl Encoder for Vec +where + T: Encoder, +{ + fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { + self.as_slice().encode(env) + } +} + +impl<'a, T> Decoder<'a> for Vec +where + T: Decoder<'a>, +{ + fn decode(term: Term<'a>) -> NifResult { + let iter: ListIterator = term.decode()?; + let res: NifResult = iter.map(|x| x.decode::()).collect(); + res + } +} + +impl Encoder for [T] +where + T: Encoder, +{ + fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { + let term_array: Vec = self.iter().map(|x| x.encode(env).as_c_arg()).collect(); + unsafe { Term::new(env, list::make_list(env.as_c_arg(), &term_array)) } + } +} + +impl<'a, T> Encoder for &'a [T] +where + T: Encoder, +{ + fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { + let term_array: Vec = self.iter().map(|x| x.encode(env).as_c_arg()).collect(); + unsafe { Term::new(env, list::make_list(env.as_c_arg(), &term_array)) } + } +} + +impl<'a> ListIterator<'a> { + pub fn new_empty(env: Env<'a>) -> Self { + let term = unsafe { get_enif_make_list()(env.as_c_arg(), 0) }; + unsafe { Self::wrap_unchecked(Term::new(env, term)) } + } + + pub fn len(&self) -> Option { + let mut len: u32 = 0; + let success = + unsafe { enif_get_list_length(self.get_env().as_c_arg(), self.as_c_arg(), &mut len) }; + + if success != 1 { + return None; + } + + Some(len as usize) + } + + pub fn empty(&self) -> bool { + self.is_empty_list() + } + + pub fn get_cell(&self) -> Option<(Term<'a>, Term<'a>)> { + let env = self.get_env(); + let mut head = MaybeUninit::uninit(); + let mut tail = MaybeUninit::uninit(); + let success = unsafe { + enif_get_list_cell( + env.as_c_arg(), + self.as_c_arg(), + head.as_mut_ptr(), + tail.as_mut_ptr(), + ) + }; + + if success != 1 { + return None; + } + + unsafe { + Some(( + Term::new(env, head.assume_init()), + Term::new(env, tail.assume_init()), + )) + } + } + + pub fn prepend(&self, head: impl Encoder) -> Self { + Term::list_prepend(self.unwrap(), head) + } + + pub fn reverse(&self) -> Option { + let env = self.get_env(); + let mut list_out = MaybeUninit::uninit(); + let success = unsafe { + enif_make_reverse_list(env.as_c_arg(), self.as_c_arg(), list_out.as_mut_ptr()) + }; + + if success != 1 { + return None; + } + + let term = unsafe { Self::wrap_unchecked(Term::new(env, list_out.assume_init())) }; + + Some(term) + } +} + +/// ## List terms +impl<'a> Term<'a> { + /// Returns a new empty list. + pub fn list_new_empty(env: Env<'a>) -> ListIterator<'a> { + ListIterator::new_empty(env) + } + + /// Returns an iterator over a list term. + /// See documentation for ListIterator for more information. + /// + /// Returns None if the term is not a list. + pub fn into_list_iterator(self) -> NifResult> { + Ok(ListIterator::wrap(self)?) + } + + /// Returns the length of a list term. + /// + /// Returns None if the term is not a list. + /// + /// ### Elixir equivalent + /// ```elixir + /// length(self_term) + /// ``` + pub fn list_length(self) -> NifResult { + let iter: ListIterator = self.try_into()?; + iter.len().ok_or(Error::BadArg) + } + + /// Unpacks a single cell at the head of a list term, + /// and returns the result as a tuple of (head, tail). + /// + /// Returns None if the term is not a list. + /// + /// ### Elixir equivalent + /// ```elixir + /// [head, tail] = self_term + /// {head, tail} + /// ``` + pub fn list_get_cell(self) -> NifResult, Term<'a>)>> { + let iter: ListIterator = self.try_into()?; + + Ok(iter.get_cell()) + } + + /// Makes a copy of the self list term and reverses it. + /// + /// Returns Err(Error::BadArg) if the term is not a list. + pub fn list_reverse(self) -> NifResult> { + let iter: ListIterator = self.try_into()?; + iter.reverse().ok_or(Error::BadArg) + } + + /// Adds `head` in a list cell with `self` as tail. + pub fn list_prepend(self, head: impl Encoder) -> ListIterator<'a> { + let env = self.get_env(); + let head = head.encode(env); + + unsafe { + let term = enif_make_list_cell(env.as_c_arg(), head.as_c_arg(), self.as_c_arg()); + ListIterator::wrap_unchecked(Term::new(env, term)) + } + } +} diff --git a/rustler/src/types/map.rs b/rustler/src/wrapped_types/map.rs similarity index 86% rename from rustler/src/types/map.rs rename to rustler/src/wrapped_types/map.rs index 0e6e59de..486aa8ea 100644 --- a/rustler/src/types/map.rs +++ b/rustler/src/wrapped_types/map.rs @@ -1,12 +1,66 @@ //! Utilities used to access and create Erlang maps. use super::atom; +use crate::sys::{enif_get_map_value, enif_make_map_put, enif_make_new_map}; use crate::wrapper::map; -use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; +use std::mem::MaybeUninit; use std::ops::RangeInclusive; -pub fn map_new(env: Env) -> Term { - unsafe { Term::new(env, map::map_new(env.as_c_arg())) } +wrapper!( + /// A wrapper around an Erlang map term. + struct Map(TermType::Map) +); + +impl<'a> Env<'a> { + pub fn new_map(self) -> Map<'a> { + unsafe { Map::wrap_unchecked(Term::new(self, enif_make_new_map(self.as_c_arg()))) } + } +} + +impl<'a> Map<'a> { + pub fn get(&self, key: impl Encoder) -> Option> { + let env = self.get_env(); + let key = key.encode(env); + + let mut result = MaybeUninit::uninit(); + let success = unsafe { + enif_get_map_value( + env.as_c_arg(), + self.as_c_arg(), + key.as_c_arg(), + result.as_mut_ptr(), + ) + }; + + if success != 1 { + return None; + } + + unsafe { Some(Term::new(env, result.assume_init())) } + } + + pub fn put(&self, key: impl Encoder, value: impl Encoder) -> NifResult { + let env = self.get_env(); + let key = key.encode(env); + let value = value.encode(env); + + let mut result = MaybeUninit::uninit(); + let success = unsafe { + enif_make_map_put( + env.as_c_arg(), + self.as_c_arg(), + key.as_c_arg(), + value.as_c_arg(), + result.as_mut_ptr(), + ) + }; + + if success != 1 { + return Err(Error::BadArg); + } + unsafe { Ok(Map::wrap_ptr_unchecked(env, result.assume_init())) } + } } /// ## Map terms @@ -18,7 +72,7 @@ impl<'a> Term<'a> { /// %{} /// ``` pub fn map_new(env: Env<'a>) -> Term<'a> { - map_new(env) + env.new_map().unwrap() } /// Construct a new map from two vectors diff --git a/rustler/src/wrapped_types/mod.rs b/rustler/src/wrapped_types/mod.rs new file mode 100644 index 00000000..c7561f73 --- /dev/null +++ b/rustler/src/wrapped_types/mod.rs @@ -0,0 +1,7 @@ +mod list; +mod map; +mod tuple; +mod wrapper; + +pub(crate) use wrapper::wrapper; +pub(crate) use list::ListIterator; diff --git a/rustler/src/types/tuple.rs b/rustler/src/wrapped_types/tuple.rs similarity index 64% rename from rustler/src/types/tuple.rs rename to rustler/src/wrapped_types/tuple.rs index 5048e710..6e1861f2 100644 --- a/rustler/src/types/tuple.rs +++ b/rustler/src/wrapped_types/tuple.rs @@ -1,21 +1,55 @@ -use crate::wrapper::{tuple, NIF_TERM}; -use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; +use crate::{ + sys::{enif_get_tuple, ERL_NIF_TERM}, + Decoder, Encoder, Env, Error, NifResult, Term, TermType, +}; -/// Convert an Erlang tuple to a Rust vector. (To convert to a Rust tuple, use `term.decode()` -/// instead.) -/// -/// # Errors -/// `badarg` if `term` is not a tuple. -pub fn get_tuple(term: Term) -> Result, Error> { +use std::{ffi::c_int, mem::MaybeUninit, ops::Index}; +wrapper!( + struct Tuple(TermType::Tuple) +); + +pub unsafe fn get_tuple<'a>(term: Term<'a>) -> NifResult<&'a [ERL_NIF_TERM]> { let env = term.get_env(); - unsafe { - match tuple::get_tuple(env.as_c_arg(), term.as_c_arg()) { - Ok(terms) => Ok(terms - .iter() - .map(|x| Term::new(env, *x)) - .collect::>()), - Err(_error) => Err(Error::BadArg), - } + let mut arity: c_int = 0; + let mut array_ptr = MaybeUninit::uninit(); + let success = enif_get_tuple( + env.as_c_arg(), + term.as_c_arg(), + &mut arity, + array_ptr.as_mut_ptr(), + ); + if success != 1 { + return Err(Error::BadArg); + } + let term_array = ::std::slice::from_raw_parts(array_ptr.assume_init(), arity as usize); + Ok(term_array) +} + +impl<'a> Tuple<'a> { + pub fn size(&self) -> usize { + self.get_elements().len() + } + + pub fn get(&self, index: usize) -> Option> { + self.get_elements() + .get(index) + .map(|ptr| unsafe { Term::new(self.get_env(), ptr) }) + } + + /// Convert an Erlang tuple to a Rust vector. (To convert to a Rust tuple, use `term.decode()` + /// instead.) + /// + /// # Errors + /// `badarg` if `term` is not a tuple. + pub fn to_vec(&self) -> Vec> { + self.get_elements() + .iter() + .map(|ptr| unsafe { Term::new(self.get_env(), *ptr) }) + .collect() + } + + fn get_elements(&self) -> &'a [ERL_NIF_TERM] { + unsafe { get_tuple(self.unwrap()).unwrap() } } } diff --git a/rustler/src/wrapped_types/wrapper.rs b/rustler/src/wrapped_types/wrapper.rs new file mode 100644 index 00000000..42726ba0 --- /dev/null +++ b/rustler/src/wrapped_types/wrapper.rs @@ -0,0 +1,106 @@ +use crate::sys::ERL_NIF_TERM; +use crate::{Env, Term, TermType}; + +pub struct WrapperError; + +impl From for crate::Error { + fn from(_: WrapperError) -> Self { + crate::Error::BadArg + } +} + +pub(crate) trait Wrapper<'a>: Sized { + const WRAPPED_TYPE: TermType; + + unsafe fn wrap_ptr_unchecked(env: Env<'a>, ptr: ERL_NIF_TERM) -> Self { + Self::wrap_unchecked(Term::new(env, ptr)) + } + + fn wrap(term: Term<'a>) -> Result { + if term.get_type() == Self::WRAPPED_TYPE { + unsafe { Ok(Self::wrap_unchecked(term)) } + } else { + Err(WrapperError) + } + } + + fn unwrap(&self) -> Term<'a>; + unsafe fn wrap_unchecked(term: Term<'a>) -> Self; +} + +impl<'a, T> From for Term<'a> +where + T: Wrapper<'a>, +{ + fn from(term: T) -> Self { + term.unwrap() + } +} + +macro_rules! wrapper { + ( + $(#[$meta:meta])* + struct $name:ident($term_type:path) + ) => { + $(#[$meta])* + #[derive(PartialEq, Eq, Clone, Copy)] + pub struct $name<'a>(Term<'a>); + + use $crate::types::wrapper::Wrapper; + + impl<'a> $name<'a> { + /// Returns a representation of self in the given Env. + /// + /// If the term is already is in the provided env, it will be directly returned. + /// Otherwise the term will be copied over. + pub fn in_env<'b>(&self, env: Env<'b>) -> $name<'b> { + let term = self.unwrap().in_env(env); + unsafe { $name::wrap_unchecked(term) } + } + } + + impl<'a> Wrapper<'a> for $name<'a> { + const WRAPPED_TYPE: $crate::TermType = $term_type; + + unsafe fn wrap_unchecked(term: Term<'a>) -> Self { + $name(term) + } + + fn unwrap(&self) -> Term<'a> { + self.0 + } + } + + impl<'a> std::ops::Deref for $name<'a> { + type Target = Term<'a>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl<'a> TryFrom> for $name<'a> { + type Error = $crate::Error; + + fn try_from(term: Term<'a>) -> Result { + use $crate::types::wrapper::Wrapper; + Self::wrap(term).or(Err($crate::Error::BadArg)) + } + } + + impl<'a> $crate::Decoder<'a> for $name<'a> { + fn decode(term: Term<'a>) -> $crate::NifResult { + use $crate::types::wrapper::Wrapper; + Self::wrap(term).or(Err($crate::Error::BadArg)) + } + } + + impl<'a> $crate::Encoder for $name<'a> { + fn encode<'b>(&self, env: $crate::Env<'b>) -> Term<'b> { + self.0.encode(env) + } + } + }; +} + +pub(crate) use wrapper; diff --git a/rustler/src/wrapper/list.rs b/rustler/src/wrapper/list.rs index 33078586..312eb0eb 100644 --- a/rustler/src/wrapper/list.rs +++ b/rustler/src/wrapper/list.rs @@ -1,47 +1,7 @@ use crate::{ - sys::{ - enif_get_list_cell, enif_get_list_length, enif_make_list_cell, enif_make_list_from_array, - enif_make_reverse_list, - }, + sys::enif_make_list_from_array, wrapper::{NIF_ENV, NIF_TERM}, }; -use std::mem::MaybeUninit; - -pub unsafe fn get_list_cell(env: NIF_ENV, list: NIF_TERM) -> Option<(NIF_TERM, NIF_TERM)> { - let mut head = MaybeUninit::uninit(); - let mut tail = MaybeUninit::uninit(); - let success = enif_get_list_cell(env, list, head.as_mut_ptr(), tail.as_mut_ptr()); - - if success != 1 { - return None; - } - Some((head.assume_init(), tail.assume_init())) -} - -pub unsafe fn get_list_length(env: NIF_ENV, list: NIF_TERM) -> Option { - let mut len: u32 = 0; - let success = enif_get_list_length(env, list, &mut len); - - if success != 1 { - return None; - } - Some(len as usize) -} - pub unsafe fn make_list(env: NIF_ENV, arr: &[NIF_TERM]) -> NIF_TERM { enif_make_list_from_array(env, arr.as_ptr(), arr.len() as u32) } - -pub unsafe fn make_list_cell(env: NIF_ENV, head: NIF_TERM, tail: NIF_TERM) -> NIF_TERM { - enif_make_list_cell(env, head, tail) -} - -pub unsafe fn make_reverse_list(env: NIF_ENV, list: NIF_TERM) -> Option { - let mut list_out = MaybeUninit::uninit(); - let success = enif_make_reverse_list(env, list, list_out.as_mut_ptr()); - - if success != 1 { - return None; - } - Some(list_out.assume_init()) -} diff --git a/rustler/src/wrapper/tuple.rs b/rustler/src/wrapper/tuple.rs index 6aa0cc50..621ce19d 100644 --- a/rustler/src/wrapper/tuple.rs +++ b/rustler/src/wrapper/tuple.rs @@ -1,18 +1,7 @@ -use crate::sys::{enif_get_tuple, enif_make_tuple_from_array}; +use crate::sys::enif_make_tuple_from_array; use crate::wrapper::{c_int, NIF_ENV, NIF_ERROR, NIF_TERM}; use std::mem::MaybeUninit; -pub unsafe fn get_tuple<'a>(env: NIF_ENV, term: NIF_TERM) -> Result<&'a [NIF_TERM], NIF_ERROR> { - let mut arity: c_int = 0; - let mut array_ptr = MaybeUninit::uninit(); - let success = enif_get_tuple(env, term, &mut arity, array_ptr.as_mut_ptr()); - if success != 1 { - return Err(NIF_ERROR::BAD_ARG); - } - let term_array = ::std::slice::from_raw_parts(array_ptr.assume_init(), arity as usize); - Ok(term_array) -} - pub unsafe fn make_tuple(env: NIF_ENV, terms: &[NIF_TERM]) -> NIF_TERM { enif_make_tuple_from_array(env, terms.as_ptr(), terms.len() as u32) } diff --git a/rustler_tests/native/rustler_test/src/test_env.rs b/rustler_tests/native/rustler_test/src/test_env.rs index 99f9da4f..956d295d 100644 --- a/rustler_tests/native/rustler_test/src/test_env.rs +++ b/rustler_tests/native/rustler_test/src/test_env.rs @@ -62,12 +62,12 @@ pub fn sublists<'a>(env: Env<'a>, list: Term<'a>) -> NifResult { let reversed_list = saved_reversed_list.load(env); let iter: ListIterator = reversed_list.decode()?; - let empty_list = Vec::::new().encode(env); + let empty_list = ListIterator::new_empty(env); let mut all_sublists = vec![empty_list]; for element in iter { for i in 0..all_sublists.len() { - let new_list = all_sublists[i].list_prepend(element); + let new_list = all_sublists[i].prepend(element); all_sublists.push(new_list); } } From 88c8f7ec96bfdfa487626c8f88f798c252fc359c Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Mon, 6 Jan 2025 09:40:17 +0100 Subject: [PATCH 3/7] Current state --- rustler/src/lib.rs | 8 ++--- rustler/src/types/elixir_struct.rs | 3 +- rustler/src/types/mod.rs | 17 +---------- rustler/src/wrapped_types/map.rs | 4 ++- rustler/src/wrapped_types/mod.rs | 12 ++++++-- .../src/{types => wrapped_types}/reference.rs | 2 ++ rustler/src/wrapped_types/tuple.rs | 30 +++++++++++-------- rustler/src/wrapped_types/wrapper.rs | 6 ++-- rustler/src/wrapper/tuple.rs | 3 +- 9 files changed, 40 insertions(+), 45 deletions(-) rename rustler/src/{types => wrapped_types}/reference.rs (94%) diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index fc3e7047..686511f7 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -36,16 +36,14 @@ pub mod types; mod term; mod wrapped_types; -pub use crate::wrapped_types::{ - ListIterator, Map -}; pub use crate::term::Term; pub use crate::types::{ - Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, MapIterator, NewBinary, - OwnedBinary, Reference, + Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, NewBinary, OwnedBinary }; +pub use crate::wrapped_types::{ListIterator, Reference, MapIterator, Map, Tuple}; + #[cfg(feature = "big_integer")] pub use crate::types::BigInt; diff --git a/rustler/src/types/elixir_struct.rs b/rustler/src/types/elixir_struct.rs index 7bde54a0..806d2f2d 100644 --- a/rustler/src/types/elixir_struct.rs +++ b/rustler/src/types/elixir_struct.rs @@ -8,8 +8,7 @@ //! `#[module = "Elixir.TheStructModule"]`. use super::atom::{self, Atom}; -use super::map::Map; -use crate::{Env, Error, NifResult, Term}; +use crate::{Env, Error, Map, NifResult, Term}; pub fn get_ex_struct_name(map: Term) -> NifResult { // In an Elixir struct the value in the __struct__ field is always an atom. diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index 829cded8..c1031421 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -1,8 +1,6 @@ +use crate::wrapped_types::MapIterator; use crate::{Env, Error, NifResult, Term}; -#[macro_use] -mod wrapper; - #[macro_use] pub mod atom; pub use crate::types::atom::Atom; @@ -15,28 +13,15 @@ pub mod big_int; #[cfg(feature = "big_integer")] pub use num_bigint::BigInt; -#[doc(hidden)] -pub mod list; -pub use crate::types::list::ListIterator; - -#[doc(hidden)] -pub mod map; -pub use self::map::MapIterator; - #[doc(hidden)] pub mod primitive; #[doc(hidden)] pub mod string; -pub mod tuple; #[doc(hidden)] pub mod local_pid; pub use self::local_pid::LocalPid; -#[doc(hidden)] -pub mod reference; -pub use self::reference::Reference; - pub mod i128; pub mod path; diff --git a/rustler/src/wrapped_types/map.rs b/rustler/src/wrapped_types/map.rs index 486aa8ea..94009598 100644 --- a/rustler/src/wrapped_types/map.rs +++ b/rustler/src/wrapped_types/map.rs @@ -1,12 +1,14 @@ //! Utilities used to access and create Erlang maps. -use super::atom; use crate::sys::{enif_get_map_value, enif_make_map_put, enif_make_new_map}; +use crate::types::atom; use crate::wrapper::map; use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; use std::mem::MaybeUninit; use std::ops::RangeInclusive; +use super::wrapper; + wrapper!( /// A wrapper around an Erlang map term. struct Map(TermType::Map) diff --git a/rustler/src/wrapped_types/mod.rs b/rustler/src/wrapped_types/mod.rs index c7561f73..49840de7 100644 --- a/rustler/src/wrapped_types/mod.rs +++ b/rustler/src/wrapped_types/mod.rs @@ -1,7 +1,13 @@ mod list; mod map; -mod tuple; -mod wrapper; +mod reference; +pub mod tuple; +pub mod wrapper; + +pub use list::ListIterator; +pub use map::{Map, MapIterator}; +pub use reference::Reference; +pub use tuple::Tuple; pub(crate) use wrapper::wrapper; -pub(crate) use list::ListIterator; +pub(crate) use wrapper::Wrapper; diff --git a/rustler/src/types/reference.rs b/rustler/src/wrapped_types/reference.rs similarity index 94% rename from rustler/src/types/reference.rs rename to rustler/src/wrapped_types/reference.rs index 4d4873e0..4c4261d2 100644 --- a/rustler/src/types/reference.rs +++ b/rustler/src/wrapped_types/reference.rs @@ -2,6 +2,8 @@ use crate::{Env, Term, TermType}; use crate::sys::enif_make_ref; +use super::wrapper; + wrapper!{ struct Reference(TermType::Ref) } diff --git a/rustler/src/wrapped_types/tuple.rs b/rustler/src/wrapped_types/tuple.rs index 6e1861f2..a814c642 100644 --- a/rustler/src/wrapped_types/tuple.rs +++ b/rustler/src/wrapped_types/tuple.rs @@ -1,9 +1,10 @@ -use crate::{ - sys::{enif_get_tuple, ERL_NIF_TERM}, - Decoder, Encoder, Env, Error, NifResult, Term, TermType, -}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; +use crate::sys::{enif_get_tuple, enif_make_tuple_from_array, ERL_NIF_TERM}; -use std::{ffi::c_int, mem::MaybeUninit, ops::Index}; +use std::ffi::c_int; +use std::mem::MaybeUninit; + +use super::wrapper; wrapper!( struct Tuple(TermType::Tuple) ); @@ -33,7 +34,7 @@ impl<'a> Tuple<'a> { pub fn get(&self, index: usize) -> Option> { self.get_elements() .get(index) - .map(|ptr| unsafe { Term::new(self.get_env(), ptr) }) + .map(|ptr| unsafe { Term::new(self.get_env(), *ptr) }) } /// Convert an Erlang tuple to a Rust vector. (To convert to a Rust tuple, use `term.decode()` @@ -56,8 +57,13 @@ impl<'a> Tuple<'a> { /// Convert a vector of terms to an Erlang tuple. (To convert from a Rust tuple to an Erlang tuple, /// use `Encoder` instead.) pub fn make_tuple<'a>(env: Env<'a>, terms: &[Term]) -> Term<'a> { - let c_terms: Vec = terms.iter().map(|term| term.as_c_arg()).collect(); - unsafe { Term::new(env, tuple::make_tuple(env.as_c_arg(), &c_terms)) } + let c_terms: Vec = terms.iter().map(|term| term.as_c_arg()).collect(); + unsafe { + let term = + enif_make_tuple_from_array(env.as_c_arg(), c_terms.as_ptr(), c_terms.len() as u32); + Term::new(env, term) + } + // unsafe { Term::new(env, tuple::make_tuple(env.as_c_arg(), &c_terms)) } } /// Helper macro to emit tuple-like syntax. Wraps its arguments in parentheses, and adds a comma if @@ -83,10 +89,8 @@ macro_rules! impl_nifencoder_nifdecoder_for_tuple { Encoder for tuple!( $( $tyvar ),* ) { fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - let arr = [ $( Encoder::encode(&self.$index, env).as_c_arg() ),* ]; - unsafe { - Term::new(env, tuple::make_tuple(env.as_c_arg(), &arr)) - } + let arr = [ $( Encoder::encode(&self.$index, env) ),* ]; + make_tuple(env, &arr) } } @@ -95,7 +99,7 @@ macro_rules! impl_nifencoder_nifdecoder_for_tuple { { fn decode(term: Term<'a>) -> NifResult { - match unsafe { tuple::get_tuple(term.get_env().as_c_arg(), term.as_c_arg()) } { + match unsafe { get_tuple(term) } { Ok(elements) if elements.len() == count!( $( $index ),* ) => Ok(tuple!( $( (<$tyvar as Decoder>::decode( diff --git a/rustler/src/wrapped_types/wrapper.rs b/rustler/src/wrapped_types/wrapper.rs index 42726ba0..ec0fbb66 100644 --- a/rustler/src/wrapped_types/wrapper.rs +++ b/rustler/src/wrapped_types/wrapper.rs @@ -46,7 +46,7 @@ macro_rules! wrapper { #[derive(PartialEq, Eq, Clone, Copy)] pub struct $name<'a>(Term<'a>); - use $crate::types::wrapper::Wrapper; + use $crate::wrapped_types::Wrapper; impl<'a> $name<'a> { /// Returns a representation of self in the given Env. @@ -83,14 +83,14 @@ macro_rules! wrapper { type Error = $crate::Error; fn try_from(term: Term<'a>) -> Result { - use $crate::types::wrapper::Wrapper; + use $crate::wrapped_types::Wrapper; Self::wrap(term).or(Err($crate::Error::BadArg)) } } impl<'a> $crate::Decoder<'a> for $name<'a> { fn decode(term: Term<'a>) -> $crate::NifResult { - use $crate::types::wrapper::Wrapper; + use $crate::wrapped_types::Wrapper; Self::wrap(term).or(Err($crate::Error::BadArg)) } } diff --git a/rustler/src/wrapper/tuple.rs b/rustler/src/wrapper/tuple.rs index 621ce19d..88520dcd 100644 --- a/rustler/src/wrapper/tuple.rs +++ b/rustler/src/wrapper/tuple.rs @@ -1,6 +1,5 @@ use crate::sys::enif_make_tuple_from_array; -use crate::wrapper::{c_int, NIF_ENV, NIF_ERROR, NIF_TERM}; -use std::mem::MaybeUninit; +use crate::wrapper::{NIF_ENV, NIF_TERM}; pub unsafe fn make_tuple(env: NIF_ENV, terms: &[NIF_TERM]) -> NIF_TERM { enif_make_tuple_from_array(env, terms.as_ptr(), terms.len() as u32) From ff8e419ffb1ca3f6b011eda1b3dae243cadc8672 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Sun, 12 Jan 2025 20:34:09 +0100 Subject: [PATCH 4/7] Fix serde feature --- rustler/src/serde/de.rs | 5 +---- rustler/src/serde/ser.rs | 4 ++-- rustler/src/serde/util.rs | 7 +++++-- rustler/src/wrapped_types/tuple.rs | 6 +++++- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/rustler/src/serde/de.rs b/rustler/src/serde/de.rs index 309a1999..f99d74b8 100644 --- a/rustler/src/serde/de.rs +++ b/rustler/src/serde/de.rs @@ -1,8 +1,5 @@ use crate::serde::{atoms, error::Error, util}; -use crate::{ - types::{ListIterator, MapIterator}, - Term, TermType, -}; +use crate::{ListIterator, MapIterator, Term, TermType}; use serde::{ de::{ self, Deserialize, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, diff --git a/rustler/src/serde/ser.rs b/rustler/src/serde/ser.rs index 6071c5a1..92d9dd9b 100644 --- a/rustler/src/serde/ser.rs +++ b/rustler/src/serde/ser.rs @@ -2,7 +2,7 @@ use std::io::Write; use crate::serde::{atoms, error::Error, util}; use crate::wrapper::list::make_list; -use crate::{types::tuple, Encoder, Env, OwnedBinary, Term}; +use crate::{Encoder, Env, OwnedBinary, Term, Tuple}; use serde::ser::{self, Serialize}; #[inline] @@ -336,7 +336,7 @@ impl<'a> SequenceSerializer<'a> { #[inline] fn to_tuple(&self) -> Result, Error> { - Ok(tuple::make_tuple(self.ser.env, &self.items)) + Ok(Tuple::make(self.ser.env, &self.items)) } } diff --git a/rustler/src/serde/util.rs b/rustler/src/serde/util.rs index d479139a..84b6f6d2 100644 --- a/rustler/src/serde/util.rs +++ b/rustler/src/serde/util.rs @@ -1,5 +1,5 @@ use crate::serde::{atoms, Error}; -use crate::{types::tuple, Binary, Decoder, Encoder, Env, Term}; +use crate::{Binary, Decoder, Encoder, Env, Term, Tuple}; /// Converts an `&str` to either an existing atom or an Elixir bitstring. pub fn str_to_term<'a>(env: &Env<'a>, string: &str) -> Result, Error> { @@ -62,7 +62,10 @@ pub fn validate_tuple(term: Term, len: Option) -> Result, Error return Err(Error::ExpectedTuple); } - let tuple = tuple::get_tuple(term).or(Err(Error::ExpectedTuple))?; + let tuple = Tuple::try_from(term) + .or(Err(Error::ExpectedTuple))? + .to_vec(); + match len { None => Ok(tuple), Some(len) => { diff --git a/rustler/src/wrapped_types/tuple.rs b/rustler/src/wrapped_types/tuple.rs index a814c642..758485ba 100644 --- a/rustler/src/wrapped_types/tuple.rs +++ b/rustler/src/wrapped_types/tuple.rs @@ -1,5 +1,5 @@ -use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; use crate::sys::{enif_get_tuple, enif_make_tuple_from_array, ERL_NIF_TERM}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term, TermType}; use std::ffi::c_int; use std::mem::MaybeUninit; @@ -27,6 +27,10 @@ pub unsafe fn get_tuple<'a>(term: Term<'a>) -> NifResult<&'a [ERL_NIF_TERM]> { } impl<'a> Tuple<'a> { + pub fn make(env: Env<'a>, terms: &[Term<'a>]) -> Term<'a> { + make_tuple(env, terms) + } + pub fn size(&self) -> usize { self.get_elements().len() } From 5205a56ed9c25ae1591f10923972b625457ea2f6 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Sun, 12 Jan 2025 20:48:46 +0100 Subject: [PATCH 5/7] Fix doctest --- rustler/src/wrapped_types/list.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rustler/src/wrapped_types/list.rs b/rustler/src/wrapped_types/list.rs index d92e7611..9abf076f 100644 --- a/rustler/src/wrapped_types/list.rs +++ b/rustler/src/wrapped_types/list.rs @@ -29,8 +29,7 @@ wrapper!( /// the `Result`s out of the list. (Contains extra type annotations for clarity) /// /// ``` - /// # use rustler::{Term, NifResult}; - /// # use rustler::types::list::ListIterator; + /// # use rustler::{Term, NifResult, ListIterator}; /// # fn list_iterator_example(list_term: Term) -> NifResult> { /// let list_iterator: ListIterator = list_term.decode()?; /// @@ -55,7 +54,7 @@ impl<'a> Iterator for ListIterator<'a> { // TODO: This is unsafe as tail might not be a list. self.0 = tail; Some(head) - }, + } None => { if self.is_empty_list() { // We reached the end of the list, finish the iterator. From b93691f6f1eb26e8733e35ba461048064efa1e92 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Sun, 12 Jan 2025 23:06:13 +0100 Subject: [PATCH 6/7] Fix more tests, some renamings, move function to Env --- rustler/src/lib.rs | 4 ++-- rustler/src/serde/ser.rs | 4 ++-- rustler/src/wrapped_types/list.rs | 4 ++-- rustler/src/wrapped_types/reference.rs | 2 +- rustler/src/wrapped_types/tuple.rs | 21 ++++++++++++------- .../src/encode_decode_templates.rs | 2 +- rustler_codegen/src/record.rs | 17 +++++++++------ rustler_codegen/src/tagged_enum.rs | 12 +++++------ rustler_codegen/src/tuple.rs | 16 ++++++++------ .../native/rustler_serde_test/src/lib.rs | 6 +++--- .../native/rustler_test/src/test_env.rs | 3 +-- .../native/rustler_test/src/test_map.rs | 16 +++++++------- 12 files changed, 59 insertions(+), 48 deletions(-) diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index 686511f7..9b391722 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -39,10 +39,10 @@ mod wrapped_types; pub use crate::term::Term; pub use crate::types::{ - Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, NewBinary, OwnedBinary + Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, NewBinary, OwnedBinary, }; -pub use crate::wrapped_types::{ListIterator, Reference, MapIterator, Map, Tuple}; +pub use crate::wrapped_types::{ListIterator, Map, MapIterator, Reference, Tuple}; #[cfg(feature = "big_integer")] pub use crate::types::BigInt; diff --git a/rustler/src/serde/ser.rs b/rustler/src/serde/ser.rs index 92d9dd9b..06481d18 100644 --- a/rustler/src/serde/ser.rs +++ b/rustler/src/serde/ser.rs @@ -2,7 +2,7 @@ use std::io::Write; use crate::serde::{atoms, error::Error, util}; use crate::wrapper::list::make_list; -use crate::{Encoder, Env, OwnedBinary, Term, Tuple}; +use crate::{Encoder, Env, OwnedBinary, Term}; use serde::ser::{self, Serialize}; #[inline] @@ -336,7 +336,7 @@ impl<'a> SequenceSerializer<'a> { #[inline] fn to_tuple(&self) -> Result, Error> { - Ok(Tuple::make(self.ser.env, &self.items)) + Ok(self.ser.env.make_tuple(&self.items).into()) } } diff --git a/rustler/src/wrapped_types/list.rs b/rustler/src/wrapped_types/list.rs index 9abf076f..5a3aa3a3 100644 --- a/rustler/src/wrapped_types/list.rs +++ b/rustler/src/wrapped_types/list.rs @@ -103,7 +103,7 @@ where } } -impl<'a, T> Encoder for &'a [T] +impl Encoder for &[T] where T: Encoder, { @@ -131,7 +131,7 @@ impl<'a> ListIterator<'a> { Some(len as usize) } - pub fn empty(&self) -> bool { + pub fn is_empty(&self) -> bool { self.is_empty_list() } diff --git a/rustler/src/wrapped_types/reference.rs b/rustler/src/wrapped_types/reference.rs index 4c4261d2..fcb8aedf 100644 --- a/rustler/src/wrapped_types/reference.rs +++ b/rustler/src/wrapped_types/reference.rs @@ -4,7 +4,7 @@ use crate::sys::enif_make_ref; use super::wrapper; -wrapper!{ +wrapper! { struct Reference(TermType::Ref) } diff --git a/rustler/src/wrapped_types/tuple.rs b/rustler/src/wrapped_types/tuple.rs index 758485ba..8a775c98 100644 --- a/rustler/src/wrapped_types/tuple.rs +++ b/rustler/src/wrapped_types/tuple.rs @@ -9,7 +9,7 @@ wrapper!( struct Tuple(TermType::Tuple) ); -pub unsafe fn get_tuple<'a>(term: Term<'a>) -> NifResult<&'a [ERL_NIF_TERM]> { +pub unsafe fn get_tuple(term: Term<'_>) -> NifResult<&[ERL_NIF_TERM]> { let env = term.get_env(); let mut arity: c_int = 0; let mut array_ptr = MaybeUninit::uninit(); @@ -27,11 +27,11 @@ pub unsafe fn get_tuple<'a>(term: Term<'a>) -> NifResult<&'a [ERL_NIF_TERM]> { } impl<'a> Tuple<'a> { - pub fn make(env: Env<'a>, terms: &[Term<'a>]) -> Term<'a> { - make_tuple(env, terms) + pub fn is_empty(&self) -> bool { + self.len() == 0 } - pub fn size(&self) -> usize { + pub fn len(&self) -> usize { self.get_elements().len() } @@ -58,16 +58,21 @@ impl<'a> Tuple<'a> { } } +impl<'a> Env<'a> { + pub fn make_tuple(&self, terms: &[Term<'a>]) -> Tuple<'a> { + make_tuple(*self, terms) + } +} + /// Convert a vector of terms to an Erlang tuple. (To convert from a Rust tuple to an Erlang tuple, /// use `Encoder` instead.) -pub fn make_tuple<'a>(env: Env<'a>, terms: &[Term]) -> Term<'a> { +pub fn make_tuple<'a>(env: Env<'a>, terms: &[Term]) -> Tuple<'a> { let c_terms: Vec = terms.iter().map(|term| term.as_c_arg()).collect(); unsafe { let term = enif_make_tuple_from_array(env.as_c_arg(), c_terms.as_ptr(), c_terms.len() as u32); - Term::new(env, term) + Term::new(env, term).try_into().unwrap() } - // unsafe { Term::new(env, tuple::make_tuple(env.as_c_arg(), &c_terms)) } } /// Helper macro to emit tuple-like syntax. Wraps its arguments in parentheses, and adds a comma if @@ -94,7 +99,7 @@ macro_rules! impl_nifencoder_nifdecoder_for_tuple { { fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { let arr = [ $( Encoder::encode(&self.$index, env) ),* ]; - make_tuple(env, &arr) + make_tuple(env, &arr).into() } } diff --git a/rustler_codegen/src/encode_decode_templates.rs b/rustler_codegen/src/encode_decode_templates.rs index f1e005ad..a6dab02a 100644 --- a/rustler_codegen/src/encode_decode_templates.rs +++ b/rustler_codegen/src/encode_decode_templates.rs @@ -153,7 +153,7 @@ pub(crate) fn encoder(ctx: &Context, inner: TokenStream) -> TokenStream { impl #impl_generics ::rustler::Encoder for #ident #ty_generics #where_clause { #[allow(clippy::needless_borrow)] fn encode<'__rustler__encode_lifetime>(&self, env: ::rustler::Env<'__rustler__encode_lifetime>) -> ::rustler::Term<'__rustler__encode_lifetime> { - #inner + #inner.into() } } } diff --git a/rustler_codegen/src/record.rs b/rustler_codegen/src/record.rs index 6d418e87..4a85a7ae 100644 --- a/rustler_codegen/src/record.rs +++ b/rustler_codegen/src/record.rs @@ -103,7 +103,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T quote! { use #atoms_module_name::*; - let terms = match ::rustler::types::tuple::get_tuple(term) { + let terms = match ::rustler::Tuple::try_from(term) { Err(_) => return Err(::rustler::Error::RaiseTerm( Box::new(format!("Invalid Record structure for {}", #struct_name_str)))), Ok(value) => value, @@ -113,19 +113,24 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T return Err(::rustler::Error::RaiseAtom("invalid_record")); } - let tag : ::rustler::types::atom::Atom = terms[0].decode()?; + let tag : ::rustler::types::atom::Atom = terms.get(0).unwrap().decode()?; if tag != atom_tag() { return Err(::rustler::Error::RaiseAtom("invalid_record")); } - fn try_decode_index<'a, T>(terms: &[::rustler::Term<'a>], pos_in_struct: &str, index: usize) -> ::rustler::NifResult + fn try_decode_index<'a, T>(terms: &::rustler::Tuple<'a>, pos_in_struct: &str, index: usize) -> ::rustler::NifResult where T: rustler::Decoder<'a>, { - match ::rustler::Decoder::decode(terms[index]) { + use ::rustler::{Decoder, Error}; + + let term = terms.get(index).ok_or_else(|| Error::BadArg)?; + + match term.decode() { Err(_) => Err(::rustler::Error::RaiseTerm(Box::new( - format!("Could not decode field {} on Record {}", pos_in_struct, #struct_name_str)))), + format!("Could not decode field {} on Record {}", pos_in_struct, #struct_name_str)) + )), Ok(value) => Ok(value) } } @@ -166,7 +171,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T use ::rustler::Encoder; let arr = #field_list_ast; - ::rustler::types::tuple::make_tuple(env, &arr) + env.make_tuple(&arr) }, ) } diff --git a/rustler_codegen/src/tagged_enum.rs b/rustler_codegen/src/tagged_enum.rs index 18678b36..3ddcd8f5 100644 --- a/rustler_codegen/src/tagged_enum.rs +++ b/rustler_codegen/src/tagged_enum.rs @@ -141,10 +141,10 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) if let Ok(unit) = ::rustler::types::atom::Atom::from_term(term) { #(#unit_decoders)* - } else if let Ok(tuple) = ::rustler::types::tuple::get_tuple(term) { + } else if let Ok(tuple) = ::rustler::Tuple::try_from(term) { let name = tuple .get(0) - .and_then(|&first| ::rustler::types::atom::Atom::from_term(first).ok()) + .and_then(|first| ::rustler::types::atom::Atom::from_term(first).ok()) .ok_or(::rustler::Error::RaiseAtom("invalid_variant"))?; #(#named_unnamed_decoders)* } @@ -208,7 +208,7 @@ fn gen_unnamed_decoder<'a>( let i = i + 1; let ty = &f.ty; quote! { - <#ty as ::rustler::Decoder>::decode(tuple[#i]).map_err(|_| ::rustler::Error::RaiseTerm( + <#ty as ::rustler::Decoder>::decode(tuple.get(#i).unwrap()).map_err(|_| ::rustler::Error::RaiseTerm( Box::new(format!("Could not decode field on position {}", #i)) ))? } @@ -250,7 +250,7 @@ fn gen_named_decoder( let enum_name_string = enum_name.to_string(); let assignment = quote_spanned! { field.span() => - let #variable = try_decode_field(tuple[1], #atom_fun()).map_err(|_|{ + let #variable = try_decode_field(tuple.get(1).unwrap(), #atom_fun()).map_err(|_|{ ::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field '{}' on Enum '{}'", #ident_string, #enum_name_string @@ -267,7 +267,7 @@ fn gen_named_decoder( quote! { if tuple.len() == 2 && name == #atom_fn() { - let len = tuple[1].map_size().map_err(|_| ::rustler::Error::RaiseTerm(Box::new( + let len = tuple.get(1).unwrap().map_size().map_err(|_| ::rustler::Error::RaiseTerm(Box::new( "The second element of the tuple must be a map" )))?; #(#assignments)* @@ -334,7 +334,7 @@ fn gen_named_encoder( } => { let map = ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]) .expect("Failed to create map"); - ::rustler::types::tuple::make_tuple(env, &[::rustler::Encoder::encode(&#atom_fn(), env), map]) + env.make_tuple(&[::rustler::Encoder::encode(&#atom_fn(), env), map]).into() } } } diff --git a/rustler_codegen/src/tuple.rs b/rustler_codegen/src/tuple.rs index 900f5608..6a4a6de5 100644 --- a/rustler_codegen/src/tuple.rs +++ b/rustler_codegen/src/tuple.rs @@ -84,18 +84,22 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { super::encode_decode_templates::decoder( ctx, quote! { - let terms = ::rustler::types::tuple::get_tuple(term)?; + let terms = ::rustler::Tuple::try_from(term)?; if terms.len() != #field_num { return Err(::rustler::Error::BadArg); } - fn try_decode_index<'a, T>(terms: &[::rustler::Term<'a>], pos_in_struct: &str, index: usize) -> ::rustler::NifResult + fn try_decode_index<'a, T>(terms: &::rustler::Tuple<'a>, pos_in_struct: &str, index: usize) -> ::rustler::NifResult where T: rustler::Decoder<'a>, { - match ::rustler::Decoder::decode(terms[index]) { - Err(_) => Err(::rustler::Error::RaiseTerm(Box::new( - format!("Could not decode field {} on {}", pos_in_struct, #struct_name_str)))), + use ::rustler::{Decoder, Error}; + + let term = terms.get(index).ok_or_else(|| Error::BadArg)?; + match term.decode() { + Err(_) => Err(Error::RaiseTerm(Box::new( + format!("Could not decode field {} on {}", pos_in_struct, #struct_name_str)) + )), Ok(value) => Ok(value) } } @@ -131,7 +135,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream { quote! { use ::rustler::Encoder; let arr = #field_list_ast; - ::rustler::types::tuple::make_tuple(env, &arr) + env.make_tuple(&arr) }, ) } diff --git a/rustler_tests/native/rustler_serde_test/src/lib.rs b/rustler_tests/native/rustler_serde_test/src/lib.rs index 4fba9361..9e3f4207 100644 --- a/rustler_tests/native/rustler_serde_test/src/lib.rs +++ b/rustler_tests/native/rustler_serde_test/src/lib.rs @@ -11,7 +11,7 @@ mod types; use crate::types::Animal; use rustler::serde::{atoms, Deserializer, Error, Serializer}; -use rustler::{types::tuple, Encoder, Env, NifResult, SerdeTerm, Term}; +use rustler::{Encoder, Env, NifResult, SerdeTerm, Term}; init!("Elixir.SerdeRustlerTests"); @@ -56,10 +56,10 @@ where fn ok_tuple<'a>(env: Env<'a>, term: Term<'a>) -> Term<'a> { let ok_atom_term = atoms::ok().encode(env); - tuple::make_tuple(env, &[ok_atom_term, term]) + env.make_tuple(&[ok_atom_term, term]).into() } fn error_tuple<'a>(env: Env<'a>, reason_term: Term<'a>) -> Term<'a> { let err_atom_term = atoms::error().encode(env); - tuple::make_tuple(env, &[err_atom_term, reason_term]) + env.make_tuple(&[err_atom_term, reason_term]).into() } diff --git a/rustler_tests/native/rustler_test/src/test_env.rs b/rustler_tests/native/rustler_test/src/test_env.rs index 956d295d..93a5f86a 100644 --- a/rustler_tests/native/rustler_test/src/test_env.rs +++ b/rustler_tests/native/rustler_test/src/test_env.rs @@ -1,8 +1,7 @@ use rustler::env::{OwnedEnv, SavedTerm, SendError}; use rustler::types::atom; -use rustler::types::list::ListIterator; use rustler::types::LocalPid; -use rustler::{Atom, Encoder, Env, NifResult, Reference, Term}; +use rustler::{Atom, Encoder, Env, ListIterator, NifResult, Reference, Term}; use std::thread; // Send a message to several PIDs. diff --git a/rustler_tests/native/rustler_test/src/test_map.rs b/rustler_tests/native/rustler_test/src/test_map.rs index d4e1d5ed..1a6e0b92 100644 --- a/rustler_tests/native/rustler_test/src/test_map.rs +++ b/rustler_tests/native/rustler_test/src/test_map.rs @@ -1,6 +1,4 @@ -use rustler::types::map::MapIterator; -use rustler::types::tuple::make_tuple; -use rustler::{Encoder, Env, Error, ListIterator, NifResult, Term}; +use rustler::{Encoder, Env, Error, ListIterator, MapIterator, NifResult, Term, Tuple}; #[rustler::nif] pub fn sum_map_values(iter: MapIterator) -> NifResult { @@ -11,31 +9,31 @@ pub fn sum_map_values(iter: MapIterator) -> NifResult { } #[rustler::nif] -pub fn map_entries<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { +pub fn map_entries<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { let mut vec = vec![]; for (key, value) in iter { let key_string = key.decode::()?; vec.push((key_string, value)); } - let erlang_pairs: Vec = vec + let erlang_pairs: Vec<_> = vec .into_iter() - .map(|(key, value)| make_tuple(env, &[key.encode(env), value])) + .map(|(key, value)| env.make_tuple(&[key.encode(env), value])) .collect(); Ok(erlang_pairs) } #[rustler::nif] -pub fn map_entries_reversed<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { +pub fn map_entries_reversed<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult>> { let mut vec = vec![]; for (key, value) in iter.rev() { let key_string = key.decode::()?; vec.push((key_string, value)); } - let erlang_pairs: Vec = vec + let erlang_pairs: Vec<_> = vec .into_iter() - .map(|(key, value)| make_tuple(env, &[key.encode(env), value])) + .map(|(key, value)| env.make_tuple(&[key.encode(env), value])) .collect(); Ok(erlang_pairs) } From c6b9bd2bcea9e5ae949054f8d5fbc72528699840 Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Mon, 13 Jan 2025 21:06:54 +0100 Subject: [PATCH 7/7] Current state --- rustler/src/lib.rs | 2 +- rustler/src/serde/ser.rs | 6 +++++- rustler/src/types/atom.rs | 13 ++++++++++++ rustler/src/wrapped_types/map.rs | 20 +++++++++++++++++++ rustler/src/wrapped_types/mod.rs | 2 +- rustler/src/wrapped_types/wrapper.rs | 2 +- rustler_codegen/src/map.rs | 4 ++-- .../native/rustler_test/src/test_map.rs | 10 +++++----- 8 files changed, 48 insertions(+), 11 deletions(-) diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index 9b391722..bb4d820a 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -42,7 +42,7 @@ pub use crate::types::{ Atom, Binary, Decoder, Encoder, ErlOption, LocalPid, NewBinary, OwnedBinary, }; -pub use crate::wrapped_types::{ListIterator, Map, MapIterator, Reference, Tuple}; +pub use crate::wrapped_types::{ListIterator, Map, MapIterator, Reference, Tuple, Wrapper}; #[cfg(feature = "big_integer")] pub use crate::types::BigInt; diff --git a/rustler/src/serde/ser.rs b/rustler/src/serde/ser.rs index 06481d18..09e11113 100644 --- a/rustler/src/serde/ser.rs +++ b/rustler/src/serde/ser.rs @@ -460,7 +460,11 @@ impl<'a> MapSerializer<'a> { #[inline] fn to_map(&self) -> Result, Error> { - Term::map_from_arrays(self.ser.env, &self.keys, &self.values).or(Err(Error::InvalidMap)) + self.ser + .env + .map_from_arrays(&self.keys, &self.values) + .map(|map| map.into()) + .or(Err(Error::InvalidMap)) } #[inline] diff --git a/rustler/src/types/atom.rs b/rustler/src/types/atom.rs index 594e4ea1..115e7d2e 100644 --- a/rustler/src/types/atom.rs +++ b/rustler/src/types/atom.rs @@ -1,5 +1,6 @@ use crate::wrapper::atom; use crate::wrapper::NIF_TERM; +use crate::Wrapper; use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; // Atoms are a special case of a term. They can be stored and used on all envs regardless of where @@ -81,6 +82,18 @@ impl Atom { } } +impl<'a> Wrapper<'a> for Atom { + const WRAPPED_TYPE: crate::TermType = crate::TermType::Atom; + + fn unwrap(&self) -> Term<'a> { + unimplemented!() + } + + unsafe fn wrap_unchecked(term: Term<'a>) -> Self { + Atom::from_nif_term(term.as_c_arg()) + } +} + use std::fmt; impl fmt::Debug for Atom { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { diff --git a/rustler/src/wrapped_types/map.rs b/rustler/src/wrapped_types/map.rs index 94009598..8a9cd63b 100644 --- a/rustler/src/wrapped_types/map.rs +++ b/rustler/src/wrapped_types/map.rs @@ -18,6 +18,26 @@ impl<'a> Env<'a> { pub fn new_map(self) -> Map<'a> { unsafe { Map::wrap_unchecked(Term::new(self, enif_make_new_map(self.as_c_arg()))) } } + + pub fn map_from_pairs(self, pairs: &[(impl Encoder, impl Encoder)]) -> NifResult> { + Term::map_from_pairs(self, pairs).map(|res| res.try_into().unwrap()) + } + + pub fn map_from_arrays( + self, + keys: &[impl Encoder], + values: &[impl Encoder], + ) -> NifResult> { + Term::map_from_arrays(self, keys, values).map(|res| res.try_into().unwrap()) + } + + pub fn map_from_term_arrays( + self, + keys: &[Term<'a>], + values: &[Term<'a>], + ) -> NifResult> { + Term::map_from_term_arrays(self, keys, values).map(|res| res.try_into().unwrap()) + } } impl<'a> Map<'a> { diff --git a/rustler/src/wrapped_types/mod.rs b/rustler/src/wrapped_types/mod.rs index 49840de7..4ef1da54 100644 --- a/rustler/src/wrapped_types/mod.rs +++ b/rustler/src/wrapped_types/mod.rs @@ -10,4 +10,4 @@ pub use reference::Reference; pub use tuple::Tuple; pub(crate) use wrapper::wrapper; -pub(crate) use wrapper::Wrapper; +pub use wrapper::Wrapper; diff --git a/rustler/src/wrapped_types/wrapper.rs b/rustler/src/wrapped_types/wrapper.rs index ec0fbb66..55bf1f00 100644 --- a/rustler/src/wrapped_types/wrapper.rs +++ b/rustler/src/wrapped_types/wrapper.rs @@ -9,7 +9,7 @@ impl From for crate::Error { } } -pub(crate) trait Wrapper<'a>: Sized { +pub trait Wrapper<'a>: Sized { const WRAPPED_TYPE: TermType; unsafe fn wrap_ptr_unchecked(env: Env<'a>, ptr: ERL_NIF_TERM) -> Self { diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 51c4ef13..7449e401 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -89,7 +89,6 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T where T: rustler::Decoder<'a>, { - use rustler::Encoder; match ::rustler::Decoder::decode(term.map_get(&field)?) { Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field :{:?} on %{{}}", @@ -123,7 +122,8 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T ctx, quote! { use #atoms_module_name::*; - ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*]).unwrap() + use ::rustler::Wrapper; + env.map_from_term_arrays(&[#(#keys),*], &[#(#values),*]).unwrap().unwrap() }, ) } diff --git a/rustler_tests/native/rustler_test/src/test_map.rs b/rustler_tests/native/rustler_test/src/test_map.rs index 1a6e0b92..8bb52c0f 100644 --- a/rustler_tests/native/rustler_test/src/test_map.rs +++ b/rustler_tests/native/rustler_test/src/test_map.rs @@ -1,4 +1,4 @@ -use rustler::{Encoder, Env, Error, ListIterator, MapIterator, NifResult, Term, Tuple}; +use rustler::{Encoder, Env, Error, ListIterator, Map, MapIterator, NifResult, Term, Tuple}; #[rustler::nif] pub fn sum_map_values(iter: MapIterator) -> NifResult { @@ -43,15 +43,15 @@ pub fn map_from_arrays<'a>( env: Env<'a>, keys: Vec>, values: Vec>, -) -> NifResult> { - Term::map_from_arrays(env, &keys, &values) +) -> NifResult> { + env.map_from_arrays(&keys, &values) } #[rustler::nif] -pub fn map_from_pairs<'a>(env: Env<'a>, pairs: ListIterator<'a>) -> NifResult> { +pub fn map_from_pairs<'a>(env: Env<'a>, pairs: ListIterator<'a>) -> NifResult> { let res: Result, Error> = pairs.map(|x| x.decode()).collect(); - res.and_then(|v| Term::map_from_pairs(env, &v)) + res.and_then(|v| env.map_from_pairs(&v)) } #[rustler::nif]