Skip to content

Commit

Permalink
Add a LuaMutex type to borrow the InsideCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaka committed Dec 29, 2016
1 parent b8112e1 commit 07a4fd9
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 10 deletions.
131 changes: 126 additions & 5 deletions hlua/src/functions_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use std::marker::PhantomData;
use std::fmt::Debug;
use std::mem;
use std::ptr;
use std::sync::Mutex;
use std::sync::MutexGuard;

macro_rules! impl_function {
($name:ident, $($p:ident),*) => (
Expand Down Expand Up @@ -63,6 +65,9 @@ pub trait FunctionExt<P> {
fn call_mut(&mut self, params: P) -> Self::Output;
}

// TODO: with one argument we should require LuaRead<&'a mut InsideCallback> and
// not LuaRead<&'a InsideCallback>

macro_rules! impl_function_ext {
() => (
impl<Z, R> FunctionExt<()> for Function<Z, (), R> where Z: FnMut() -> R {
Expand Down Expand Up @@ -123,7 +128,7 @@ macro_rules! impl_function_ext {
impl<'lua, L, Z, R $(,$p: 'static)+> Push<L> for Function<Z, ($($p,)*), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut($($p),*) -> R,
($($p,)*): for<'p> LuaRead<&'p mut InsideCallback>,
($($p,)*): for<'p> LuaRead<&'p InsideCallback>,
R: for<'a> Push<&'a mut InsideCallback> + 'static
{
type Err = Void; // TODO: use `!` instead (https://github.com/rust-lang/rust/issues/35121)
Expand All @@ -149,7 +154,7 @@ macro_rules! impl_function_ext {
impl<'lua, L, Z, R $(,$p: 'static)+> PushOne<L> for Function<Z, ($($p,)*), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut($($p),*) -> R,
($($p,)*): for<'p> LuaRead<&'p mut InsideCallback>,
($($p,)*): for<'p> LuaRead<&'p InsideCallback>,
R: for<'a> Push<&'a mut InsideCallback> + 'static
{
}
Expand All @@ -168,12 +173,73 @@ impl_function_ext!(A, B, C, D, E, F, G, H);
impl_function_ext!(A, B, C, D, E, F, G, H, I);
impl_function_ext!(A, B, C, D, E, F, G, H, I, J);

pub struct LuaMutex<'a, T> {
lua: &'a InsideCallback,
index: i32,
marker: PhantomData<T>,
}

impl<'a, T> LuaRead<&'a InsideCallback> for LuaMutex<'a, T>
where T: LuaRead<InsideCallbackLockGuard<'a>>
{
#[inline]
fn lua_read_at_position(lua: &'a InsideCallback, index: i32)
-> Result<Self, &'a InsideCallback>
{
Ok(LuaMutex {
lua: lua,
index: index,
marker: PhantomData,
})
}
}

// TODO: is this necessary?
impl<'a, 'b, T> LuaRead<&'a mut &'b InsideCallback> for LuaMutex<'b, T>
where T: LuaRead<InsideCallbackLockGuard<'b>>
{
#[inline]
fn lua_read_at_position(lua: &'a mut &'b InsideCallback, index: i32)
-> Result<Self, &'a mut &'b InsideCallback>
{
Ok(LuaMutex {
lua: lua,
index: index,
marker: PhantomData,
})
}
}

impl<'a, T> LuaMutex<'a, T>
where T: LuaRead<InsideCallbackLockGuard<'a>>
{
#[inline]
pub fn lock(&self) -> Option<T> {
let lock = self.lua.lock();

match T::lua_read_at_position(lock, self.index) {
Ok(v) => Some(v),
Err(_) => None
}
}
}

/// Opaque type that represents the Lua context when inside a callback.
///
/// Some types (like `Result`) can only be returned from a callback and not written inside a
/// Lua variable. This type is here to enforce this restriction.
pub struct InsideCallback {
lua: LuaContext,
mutex: Mutex<()>,
}

impl InsideCallback {
pub fn lock(&self) -> InsideCallbackLockGuard {
InsideCallbackLockGuard {
lua: self.lua,
guard: self.mutex.lock().unwrap(),
}
}
}

unsafe impl<'a, 'lua> AsLua<'lua> for &'a InsideCallback {
Expand All @@ -197,6 +263,25 @@ unsafe impl<'a, 'lua> AsMutLua<'lua> for &'a mut InsideCallback {
}
}

pub struct InsideCallbackLockGuard<'a> {
lua: LuaContext,
guard: MutexGuard<'a, ()>,
}

unsafe impl<'a, 'lua> AsLua<'lua> for InsideCallbackLockGuard<'a> {
#[inline]
fn as_lua(&self) -> LuaContext {
self.lua
}
}

unsafe impl<'a, 'lua> AsMutLua<'lua> for InsideCallbackLockGuard<'a> {
#[inline]
fn as_mut_lua(&mut self) -> LuaContext {
self.lua
}
}

impl<'a, T, E, P> Push<&'a mut InsideCallback> for Result<T, E>
where T: Push<&'a mut InsideCallback, Err = P> + for<'b> Push<&'b mut &'a mut InsideCallback, Err = P>,
E: Debug
Expand Down Expand Up @@ -232,19 +317,19 @@ impl<'a, T, E, P> PushOne<&'a mut InsideCallback> for Result<T, E>
#[inline]
extern "C" fn wrapper<T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
where T: FunctionExt<P, Output = R>,
P: for<'p> LuaRead<&'p mut InsideCallback> + 'static,
P: for<'p> LuaRead<&'p InsideCallback>,
R: for<'p> Push<&'p mut InsideCallback>
{
// loading the object that we want to call from the Lua context
let data_raw = unsafe { ffi::lua_touserdata(lua, ffi::lua_upvalueindex(1)) };
let data: &mut T = unsafe { mem::transmute(data_raw) };

// creating a temporary Lua context in order to pass it to push & read functions
let mut tmp_lua = InsideCallback { lua: LuaContext(lua) };
let mut tmp_lua = InsideCallback { lua: LuaContext(lua), mutex: Mutex::new(()) };

// trying to read the arguments
let arguments_count = unsafe { ffi::lua_gettop(lua) } as i32;
let args = match LuaRead::lua_read_at_position(&mut tmp_lua, -arguments_count as libc::c_int) { // TODO: what if the user has the wrong params?
let args = match LuaRead::lua_read_at_position(&tmp_lua, -arguments_count as libc::c_int) { // TODO: what if the user has the wrong params?
Err(_) => {
let err_msg = format!("wrong parameter types for callback function");
match err_msg.push_to_lua(&mut tmp_lua) {
Expand Down Expand Up @@ -273,6 +358,8 @@ extern "C" fn wrapper<T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
mod tests {
use Lua;
use LuaError;
use LuaFunction;
use LuaMutex;
use function0;
use function1;
use function2;
Expand Down Expand Up @@ -392,4 +479,38 @@ mod tests {
assert_eq!(a, 20)
}

#[test]
fn lua_mutex_basic() {
let mut lua = Lua::new();

lua.set("foo", function1(|a: LuaMutex<LuaFunction<_>>| {
{
let a = a.lock().unwrap();
assert_eq!(a.call::<i32>().unwrap(), 5);
}
}));

lua.execute::<()>("function bar() return 5 end").unwrap();
lua.execute::<()>("foo(bar)").unwrap();
}

/*#[test]
fn lua_mutex_two() {
let mut lua = Lua::new();
lua.set("foo", function2(|a: LuaMutex<LuaFunction<_>>, b: LuaMutex<LuaFunction<_>>| {
{
let a = a.lock().unwrap();
assert_eq!(a.call::<i32>().unwrap(), 5);
}
{
let b = b.lock().unwrap();
assert_eq!(b.call::<i32>().unwrap(), 5);
}
}));
lua.execute::<()>("function bar() return 5 end").unwrap();
lua.execute::<()>("foo(bar)").unwrap();
}*/
}
1 change: 1 addition & 0 deletions hlua/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub use any::AnyLuaValue;
pub use functions_write::{Function, InsideCallback};
pub use functions_write::{function0, function1, function2, function3, function4, function5};
pub use functions_write::{function6, function7, function8, function9, function10};
pub use functions_write::LuaMutex;
pub use lua_functions::LuaFunction;
pub use lua_functions::LuaFunctionCallError;
pub use lua_functions::{LuaCode, LuaCodeFromReader};
Expand Down
9 changes: 4 additions & 5 deletions hlua/src/tuples.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use AsMutLua;
use AsLua;

use Push;
use PushOne;
Expand All @@ -9,7 +8,7 @@ use Void;

macro_rules! tuple_impl {
($ty:ident) => (
impl<'lua, LU, $ty> Push<LU> for ($ty,) where LU: AsMutLua<'lua>, $ty: Push<LU> {
impl<LU, $ty> Push<LU> for ($ty,) where $ty: Push<LU> {
type Err = <$ty as Push<LU>>::Err;

#[inline]
Expand All @@ -18,10 +17,10 @@ macro_rules! tuple_impl {
}
}

impl<'lua, LU, $ty> PushOne<LU> for ($ty,) where LU: AsMutLua<'lua>, $ty: PushOne<LU> {
impl<LU, $ty> PushOne<LU> for ($ty,) where $ty: PushOne<LU> {
}

impl<'lua, LU, $ty> LuaRead<LU> for ($ty,) where LU: AsMutLua<'lua>, $ty: LuaRead<LU> {
impl<LU, $ty> LuaRead<LU> for ($ty,) where $ty: LuaRead<LU> {
#[inline]
fn lua_read_at_position(lua: LU, index: i32) -> Result<($ty,), LU> {
LuaRead::lua_read_at_position(lua, index).map(|v| (v,))
Expand Down Expand Up @@ -74,7 +73,7 @@ macro_rules! tuple_impl {
#[allow(unused_assignments)]
#[allow(non_snake_case)]
impl<'lua, LU, $first: for<'a> LuaRead<&'a mut LU>, $($other: for<'a> LuaRead<&'a mut LU>),+>
LuaRead<LU> for ($first, $($other),+) where LU: AsLua<'lua>
LuaRead<LU> for ($first, $($other),+)
{
#[inline]
fn lua_read_at_position(mut lua: LU, index: i32) -> Result<($first, $($other),+), LU> {
Expand Down

0 comments on commit 07a4fd9

Please sign in to comment.