Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add a LuaMutex type to borrow the InsideCallback #118

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 146 additions & 24 deletions hlua/src/functions_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use std::marker::PhantomData;
use std::fmt::Debug;
use std::mem;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

macro_rules! impl_function {
($name:ident, $($p:ident),*) => (
Expand Down Expand Up @@ -151,6 +154,9 @@ pub trait FunctionExt<P> {
fn call_mut(&mut self, params: P) -> Self::Output;
}

// TODO: with one argument we should require LuaRead<&'a mut InsideCallback<'lua>> and
// not LuaRead<&'a InsideCallback<'lua>>

macro_rules! impl_function_ext {
() => (
impl<Z, R> FunctionExt<()> for Function<Z, (), R> where Z: FnMut() -> R {
Expand All @@ -166,7 +172,7 @@ macro_rules! impl_function_ext {
impl<'lua, L, Z, R> Push<L> for Function<Z, (), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut() -> R,
R: for<'a> Push<&'a mut InsideCallback> + 'static
R: for<'a> Push<&'a mut InsideCallback<'lua>> + 'static
{
type Err = Void; // TODO: use `!` instead (https://github.com/rust-lang/rust/issues/35121)

Expand All @@ -191,7 +197,7 @@ macro_rules! impl_function_ext {
impl<'lua, L, Z, R> PushOne<L> for Function<Z, (), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut() -> R,
R: for<'a> Push<&'a mut InsideCallback> + 'static
R: for<'a> Push<&'a mut InsideCallback<'lua>> + 'static
{
}
);
Expand All @@ -208,11 +214,11 @@ macro_rules! impl_function_ext {
}
}

impl<'lua, L, Z, R $(,$p: 'static)+> Push<L> for Function<Z, ($($p,)*), R>
impl<'lua, L, Z, R $(,$p)+> Push<L> for Function<Z, ($($p,)*), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut($($p),*) -> R,
($($p,)*): for<'p> LuaRead<&'p mut InsideCallback>,
R: for<'a> Push<&'a mut InsideCallback> + 'static
($($p,)*): LuaRead<Arc<InsideCallback<'lua>>>,
R: for<'a> Push<&'a mut InsideCallback<'lua>> + 'static
{
type Err = Void; // TODO: use `!` instead (https://github.com/rust-lang/rust/issues/35121)

Expand All @@ -234,11 +240,11 @@ macro_rules! impl_function_ext {
}
}

impl<'lua, L, Z, R $(,$p: 'static)+> PushOne<L> for Function<Z, ($($p,)*), R>
impl<'lua, L, Z, R $(,$p)+> PushOne<L> for Function<Z, ($($p,)*), R>
where L: AsMutLua<'lua>,
Z: 'lua + FnMut($($p),*) -> R,
($($p,)*): for<'p> LuaRead<&'p mut InsideCallback>,
R: for<'a> Push<&'a mut InsideCallback> + 'static
($($p,)*): LuaRead<Arc<InsideCallback<'lua>>>,
R: for<'a> Push<&'a mut InsideCallback<'lua>> + 'static
{
}
)
Expand All @@ -256,43 +262,120 @@ impl_function_ext!(A, B, C, D, E, F, G, H);
impl_function_ext!(A, B, C, D, E, F, G, H, I);
impl_function_ext!(A, B, C, D, E, F, G, H, I, J);

pub struct LuaMutex<T> {
lua: Arc<InsideCallback<'static>>, // TODO: I couldn't make it work for non-'static
index: i32,
marker: PhantomData<T>,
}

impl<T> LuaRead<Arc<InsideCallback<'static>>> for LuaMutex<T>
where T: LuaRead<InsideCallbackLockGuard<'static>>
{
#[inline]
fn lua_read_at_position(lua: Arc<InsideCallback<'static>>, index: i32)
-> Result<Self, Arc<InsideCallback<'static>>>
{
Ok(LuaMutex {
lua: lua,
index: index,
marker: PhantomData,
})
}
}

impl<T> LuaMutex<T>
where T: LuaRead<InsideCallbackLockGuard<'static>>
{
#[inline]
pub fn lock(&self) -> Option<T> {
let lock = InsideCallback::lock(self.lua.clone());

match T::lua_read_at_position(lock, self.index) {
Ok(v) => Some(v),
Err(_) => None
}
}
}

/// Opaque type that represents the Lua context when inside a callback.
///
/// Some types (like `Result`) can only be returned from a callback and not written inside a
/// Lua variable. This type is here to enforce this restriction.
pub struct InsideCallback {
pub struct InsideCallback<'lua> {
lua: LuaContext,
mutex: AtomicBool,
marker: PhantomData<&'lua ()>,
}

unsafe impl<'a, 'lua> AsLua<'lua> for &'a InsideCallback {
impl<'lua> InsideCallback<'lua> {
#[inline]
pub fn lock(me: Arc<InsideCallback<'lua>>) -> InsideCallbackLockGuard<'lua> {
let old = me.mutex.swap(true, Ordering::SeqCst);
if old {
panic!("Can't lock the InsideCallback twice simultaneously");
}

InsideCallbackLockGuard {
lua: me
}
}
}

unsafe impl<'a, 'lua> AsLua<'lua> for Arc<InsideCallback<'lua>> {
#[inline]
fn as_lua(&self) -> LuaContext {
self.lua
}
}

unsafe impl<'a, 'lua> AsLua<'lua> for &'a mut InsideCallback {
unsafe impl<'a, 'lua> AsLua<'lua> for &'a mut InsideCallback<'lua> {
#[inline]
fn as_lua(&self) -> LuaContext {
self.lua
}
}

unsafe impl<'a, 'lua> AsMutLua<'lua> for &'a mut InsideCallback {
unsafe impl<'a, 'lua> AsMutLua<'lua> for &'a mut InsideCallback<'lua> {
#[inline]
fn as_mut_lua(&mut self) -> LuaContext {
self.lua
}
}

impl<'a, T, E, P> Push<&'a mut InsideCallback> for Result<T, E>
where T: Push<&'a mut InsideCallback, Err = P> + for<'b> Push<&'b mut &'a mut InsideCallback, Err = P>,
pub struct InsideCallbackLockGuard<'lua> {
lua: Arc<InsideCallback<'lua>>
}

unsafe impl<'lua> AsLua<'lua> for InsideCallbackLockGuard<'lua> {
#[inline]
fn as_lua(&self) -> LuaContext {
self.lua.lua
}
}

unsafe impl<'lua> AsMutLua<'lua> for InsideCallbackLockGuard<'lua> {
#[inline]
fn as_mut_lua(&mut self) -> LuaContext {
self.lua.lua
}
}

impl<'lua> Drop for InsideCallbackLockGuard<'lua> {
#[inline]
fn drop(&mut self) {
let old = self.lua.mutex.swap(false, Ordering::SeqCst);
debug_assert_eq!(old, true);
}
}

impl<'a, 'lua, T, E, P> Push<&'a mut InsideCallback<'lua>> for Result<T, E>
where T: Push<&'a mut InsideCallback<'lua>, Err = P> + for<'b> Push<&'b mut &'a mut InsideCallback<'lua>, Err = P>,
E: Debug
{
type Err = P;

#[inline]
fn push_to_lua(self, mut lua: &'a mut InsideCallback) -> Result<PushGuard<&'a mut InsideCallback>, (P, &'a mut InsideCallback)> {
fn push_to_lua(self, mut lua: &'a mut InsideCallback<'lua>) -> Result<PushGuard<&'a mut InsideCallback<'lua>>, (P, &'a mut InsideCallback<'lua>)> {
unsafe {
match self {
Ok(val) => val.push_to_lua(lua),
Expand All @@ -310,32 +393,36 @@ impl<'a, T, E, P> Push<&'a mut InsideCallback> for Result<T, E>
}
}

impl<'a, T, E, P> PushOne<&'a mut InsideCallback> for Result<T, E>
where T: PushOne<&'a mut InsideCallback, Err = P> + for<'b> PushOne<&'b mut &'a mut InsideCallback, Err = P>,
impl<'a, 'lua, T, E, P> PushOne<&'a mut InsideCallback<'lua>> for Result<T, E>
where T: PushOne<&'a mut InsideCallback<'lua>, Err = P> + for<'b> PushOne<&'b mut &'a mut InsideCallback<'lua>, Err = P>,
E: Debug
{
}

// this function is called when Lua wants to call one of our functions
#[inline]
extern "C" fn wrapper<T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
extern "C" fn wrapper<'lua, T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
where T: FunctionExt<P, Output = R>,
P: for<'p> LuaRead<&'p mut InsideCallback> + 'static,
R: for<'p> Push<&'p mut InsideCallback>
P: LuaRead<Arc<InsideCallback<'lua>>>,
R: for<'p> Push<&'p mut InsideCallback<'lua>>
{
// loading the object that we want to call from the Lua context
let data_raw = unsafe { ffi::lua_touserdata(lua, ffi::lua_upvalueindex(1)) };
let data: &mut T = unsafe { mem::transmute(data_raw) };

// creating a temporary Lua context in order to pass it to push & read functions
let mut tmp_lua = InsideCallback { lua: LuaContext(lua) };
let mut tmp_lua = Arc::new(InsideCallback {
lua: LuaContext(lua),
mutex: AtomicBool::new(false),
marker: PhantomData,
});

// trying to read the arguments
let arguments_count = unsafe { ffi::lua_gettop(lua) } as i32;
let args = match LuaRead::lua_read_at_position(&mut tmp_lua, -arguments_count as libc::c_int) { // TODO: what if the user has the wrong params?
let args = match LuaRead::lua_read_at_position(tmp_lua.clone(), -arguments_count as libc::c_int) { // TODO: what if the user has the wrong params?
Err(_) => {
let err_msg = format!("wrong parameter types for callback function");
match err_msg.push_to_lua(&mut tmp_lua) {
match err_msg.push_to_lua(Arc::get_mut(&mut tmp_lua).unwrap()) {
Ok(p) => p.forget(),
Err(_) => unreachable!(),
};
Expand All @@ -350,7 +437,7 @@ extern "C" fn wrapper<T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
let ret_value = data.call_mut(args);

// pushing back the result of the function on the stack
let nb = match ret_value.push_to_lua(&mut tmp_lua) {
let nb = match ret_value.push_to_lua(Arc::get_mut(&mut tmp_lua).unwrap()) {
Ok(p) => p.forget(),
Err(_) => panic!(), // TODO: wrong
};
Expand All @@ -361,6 +448,8 @@ extern "C" fn wrapper<T, P, R>(lua: *mut ffi::lua_State) -> libc::c_int
mod tests {
use Lua;
use LuaError;
use LuaFunction;
use LuaMutex;
use function0;
use function1;
use function2;
Expand Down Expand Up @@ -480,4 +569,37 @@ mod tests {
assert_eq!(a, 20)
}

#[test]
fn lua_mutex_basic() {
let mut lua = Lua::new();

lua.set("foo", function1(|a: LuaMutex<LuaFunction<_>>| {
let mut a = a.lock().unwrap();
assert_eq!(a.call::<i32>().unwrap(), 5);
}));

lua.execute::<()>("function bar() return 5 end").unwrap();
lua.execute::<()>("foo(bar)").unwrap();
}

/* TODO: make compile
#[test]
fn lua_mutex_two() {
let mut lua = Lua::new();

lua.set("foo", function2(|a: LuaMutex<LuaFunction<_>>, b: LuaMutex<LuaFunction<_>>| {
{
let a = a.lock().unwrap();
assert_eq!(a.call::<i32>().unwrap(), 5);
}

{
let b = b.lock().unwrap();
assert_eq!(b.call::<i32>().unwrap(), 5);
}
}));

lua.execute::<()>("function bar() return 5 end").unwrap();
lua.execute::<()>("foo(bar)").unwrap();
}*/
}
15 changes: 10 additions & 5 deletions hlua/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,15 @@ pub use any::AnyLuaValue;
pub use functions_write::{Function, InsideCallback};
pub use functions_write::{function0, function1, function2, function3, function4, function5};
pub use functions_write::{function6, function7, function8, function9, function10};
pub use functions_write::LuaMutex;
pub use lua_functions::LuaFunction;
pub use lua_functions::LuaFunctionCallError;
pub use lua_functions::{LuaCode, LuaCodeFromReader};
pub use lua_tables::LuaTable;
pub use lua_tables::LuaTableIterator;
pub use tuples::TuplePushError;
pub use userdata::UserdataOnStack;
pub use userdata::{push_userdata, read_userdata};
pub use userdata::{push_userdata, read_userdata, read_mut_userdata};
pub use values::StringInLua;

mod any;
Expand Down Expand Up @@ -201,15 +202,19 @@ impl<'lua, L> PushGuard<L>
}
}

/// Trait for objects that have access to a Lua context. When using a context returned by a
/// `AsLua`, you are not allowed to modify the stack.
/// Trait for objects that have access to a Lua context.
///
/// When using a context returned by a `AsLua`, you are not allowed to modify the stack and
/// multiple codes may access that same stack at the same time.
// TODO: the lifetime should be an associated lifetime instead
pub unsafe trait AsLua<'lua> {
fn as_lua(&self) -> LuaContext;
}

/// Trait for objects that have access to a Lua context. You are allowed to modify the stack, but
/// it must be in the same state as it was when you started.
/// Trait for objects that have access to a Lua context.
///
/// You have exclusive access to the stack, but it must be in the same state as it was when you
/// started.
// TODO: the lifetime should be an associated lifetime instead
pub unsafe trait AsMutLua<'lua>: AsLua<'lua> {
/// Returns the raw Lua context.
Expand Down
47 changes: 16 additions & 31 deletions hlua/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,30 @@ macro_rules! implement_lua_push {
#[macro_export]
macro_rules! implement_lua_read {
($ty:ty) => {
impl<'s, 'c> hlua::LuaRead<&'c mut hlua::InsideCallback> for &'s mut $ty {
impl<'l, 'lua, L> hlua::LuaRead<&'l L> for &'l $ty
where L: hlua::AsLua<'lua>
{
#[inline]
fn lua_read_at_position(lua: &'c mut hlua::InsideCallback, index: i32) -> Result<&'s mut $ty, &'c mut hlua::InsideCallback> {
// FIXME:
unsafe { ::std::mem::transmute($crate::read_userdata::<$ty>(lua, index)) }
fn lua_read_at_position(lua: &'l L, index: i32) -> Result<&'l $ty, &'l L> {
hlua::read_userdata(lua, index)
}
}

impl<'s, 'c> hlua::LuaRead<&'c mut hlua::InsideCallback> for &'s $ty {
/*impl<'l, 'lua, L> hlua::LuaRead<&'l mut L> for &'l $ty
where L: hlua::AsMutLua<'lua>
{
#[inline]
fn lua_read_at_position(lua: &'c mut hlua::InsideCallback, index: i32) -> Result<&'s $ty, &'c mut hlua::InsideCallback> {
// FIXME:
unsafe { ::std::mem::transmute($crate::read_userdata::<$ty>(lua, index)) }
fn lua_read_at_position(lua: &'l mut L, index: i32) -> Result<&'l $ty, &'l mut L> {
hlua::read_userdata(lua, index)
}
}

impl<'s, 'b, 'c> hlua::LuaRead<&'b mut &'c mut hlua::InsideCallback> for &'s mut $ty {
#[inline]
fn lua_read_at_position(lua: &'b mut &'c mut hlua::InsideCallback, index: i32) -> Result<&'s mut $ty, &'b mut &'c mut hlua::InsideCallback> {
let ptr_lua = lua as *mut &mut hlua::InsideCallback;
let deref_lua = unsafe { ::std::ptr::read(ptr_lua) };
let res = Self::lua_read_at_position(deref_lua, index);
match res {
Ok(x) => Ok(x),
_ => Err(lua)
}
}
}
}*/

impl<'s, 'b, 'c> hlua::LuaRead<&'b mut &'c mut hlua::InsideCallback> for &'s $ty {
impl<'l, 'lua, L> hlua::LuaRead<&'l mut L> for &'l mut $ty
where L: hlua::AsMutLua<'lua>
{
#[inline]
fn lua_read_at_position(lua: &'b mut &'c mut hlua::InsideCallback, index: i32) -> Result<&'s $ty, &'b mut &'c mut hlua::InsideCallback> {
let ptr_lua = lua as *mut &mut hlua::InsideCallback;
let deref_lua = unsafe { ::std::ptr::read(ptr_lua) };
let res = Self::lua_read_at_position(deref_lua, index);
match res {
Ok(x) => Ok(x),
_ => Err(lua)
}
fn lua_read_at_position(lua: &'l mut L, index: i32) -> Result<&'l mut $ty, &'l mut L> {
hlua::read_mut_userdata(lua, index)
}
}
};
Expand Down
Loading