diff --git a/.circleci/config.yml b/.circleci/config.yml index 3c6dc579..e14ef9bb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,18 +20,18 @@ jobs: - restore_cache: keys: - cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }} - - run: - name: Check Formatting - command: | - rustup component add rustfmt - rustfmt --version - cargo fmt --all -- --check --color=auto - run: name: Build all targets command: cargo build --all --all-targets - run: name: Run all tests command: cargo test --all + - run: + name: Check Formatting + command: | + rustup component add rustfmt + rustfmt --version + cargo fmt --all -- --check --color=auto - save_cache: paths: - /usr/local/cargo/registry @@ -58,18 +58,18 @@ jobs: - restore_cache: keys: - cargo-cache-lua53-{{ arch }}-{{ checksum "Cargo.lock" }} - - run: - name: Check Formatting - command: | - rustup component add rustfmt - rustfmt --version - cargo fmt --all -- --check --color=auto - run: name: Build all targets command: cargo build --no-default-features --features=builtin-lua53 --all --all-targets - run: name: Run all tests command: cargo test --no-default-features --features=builtin-lua53 --all + - run: + name: Check Formatting + command: | + rustup component add rustfmt + rustfmt --version + cargo fmt --all -- --check --color=auto - save_cache: paths: - /usr/local/cargo/registry @@ -102,18 +102,18 @@ jobs: - restore_cache: keys: - cargo-cache-lua51-{{ arch }}-{{ checksum "Cargo.lock" }} - - run: - name: Check Formatting - command: | - rustup component add rustfmt - rustfmt --version - cargo fmt --all -- --check --color=auto - run: name: Build all targets command: cargo build --no-default-features --features=system-lua51 --all --all-targets - run: name: Run all tests command: cargo test --no-default-features --features=system-lua51 --all + - run: + name: Check Formatting + command: | + rustup component add rustfmt + rustfmt --version + cargo fmt --all -- --check --color=auto - save_cache: paths: - /usr/local/cargo/registry diff --git a/crates/rlua-lua51-sys/src/lib.rs b/crates/rlua-lua51-sys/src/lib.rs index 4fe7f43e..db811f47 100644 --- a/crates/rlua-lua51-sys/src/lib.rs +++ b/crates/rlua-lua51-sys/src/lib.rs @@ -68,7 +68,7 @@ pub use { pub use { bindings::lua_Alloc, bindings::lua_CFunction, bindings::lua_Debug, bindings::lua_Integer, - bindings::lua_Number, bindings::lua_State, + bindings::lua_Number, bindings::lua_State, bindings::lua_Writer, }; /* @@ -143,8 +143,8 @@ pub use bindings::lua_gc; ** miscellaneous functions */ pub use { - bindings::lua_concat, bindings::lua_error, bindings::lua_getallocf, bindings::lua_next, - bindings::lua_setallocf, + bindings::lua_concat, bindings::lua_dump, bindings::lua_error, bindings::lua_getallocf, + bindings::lua_next, bindings::lua_setallocf, }; /* diff --git a/src/context.rs b/src/context.rs index 18e0ed39..5f5ff0f0 100644 --- a/src/context.rs +++ b/src/context.rs @@ -856,10 +856,16 @@ impl<'lua> Context<'lua> { source: &[u8], name: Option<&CString>, env: Option>, + allow_binary: bool, ) -> Result> { unsafe { let _sg = StackGuard::new(self.state); assert_stack(self.state, 1); + let mode = if allow_binary { + cstr!("bt") + } else { + cstr!("t") + }; match if let Some(name) = name { loadbufferx( @@ -867,7 +873,7 @@ impl<'lua> Context<'lua> { source.as_ptr() as *const c_char, source.len(), name.as_ptr() as *const c_char, - cstr!("t"), + mode, ) } else { loadbufferx( @@ -875,7 +881,7 @@ impl<'lua> Context<'lua> { source.as_ptr() as *const c_char, source.len(), ptr::null(), - cstr!("t"), + mode, ) } { ffi::LUA_OK => { @@ -956,10 +962,12 @@ impl<'lua, 'a> Chunk<'lua, 'a> { // actual lua repl does. let mut expression_source = b"return ".to_vec(); expression_source.extend(self.source); - if let Ok(function) = - self.context - .load_chunk(&expression_source, self.name.as_ref(), self.env.clone()) - { + if let Ok(function) = self.context.load_chunk( + &expression_source, + self.name.as_ref(), + self.env.clone(), + false, + ) { function.call(()) } else { self.call(()) @@ -978,7 +986,20 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// This simply compiles the chunk without actually executing it. pub fn into_function(self) -> Result> { self.context - .load_chunk(self.source, self.name.as_ref(), self.env) + .load_chunk(self.source, self.name.as_ref(), self.env, false) + } + + /// Load this chunk into a regular `Function`. + /// + /// This simply compiles the chunk without actually executing it. + /// Unlike `into_function`, this method allows loading code previously + /// compiled and saved with `Function::dump` or `string.dump()`. + /// This method is unsafe because there is no check that the precompiled + /// Lua code is valid; if it is not this may cause a crash or other + /// undefined behaviour. + pub unsafe fn into_function_allow_binary(self) -> Result> { + self.context + .load_chunk(self.source, self.name.as_ref(), self.env, true) } } diff --git a/src/function.rs b/src/function.rs index 5dc2a888..cb98b756 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,11 +1,14 @@ use std::os::raw::c_int; use std::ptr; +use libc::c_void; + use crate::error::{Error, Result}; use crate::ffi; use crate::types::LuaRef; use crate::util::{ - assert_stack, check_stack, error_traceback, pop_error, protect_lua_closure, rotate, StackGuard, + assert_stack, check_stack, dump, error_traceback, pop_error, protect_lua_closure, rotate, + StackGuard, }; use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti}; @@ -161,4 +164,59 @@ impl<'lua> Function<'lua> { Ok(Function(lua.pop_ref())) } } + + /// Dumps the compiled representation of the function into a binary blob, + /// which can later be loaded using the unsafe Chunk::into_function_allow_binary(). + /// + /// # Examples + /// + /// ``` + /// # use rlua::{Lua, Function, Result}; + /// # fn main() -> Result<()> { + /// # Lua::new().context(|lua_context| { + /// let add2: Function = lua_context.load(r#" + /// function(a) + /// return a + 2 + /// end + /// "#).eval()?; + /// + /// let dumped = add2.dump()?; + /// + /// let reloaded = unsafe { + /// lua_context.load(&dumped) + /// .into_function_allow_binary()? + /// }; + /// assert_eq!(reloaded.call::<_, u32>(7)?, 7+2); + /// + /// # Ok(()) + /// # }) + /// # } + /// ``` + pub fn dump(&self) -> Result> { + unsafe extern "C" fn writer( + _state: *mut ffi::lua_State, + p: *const c_void, + sz: usize, + ud: *mut c_void, + ) -> c_int { + let input_slice = std::slice::from_raw_parts(p as *const u8, sz); + let vec = &mut *(ud as *mut Vec); + vec.extend_from_slice(input_slice); + 0 + } + let lua = self.0.lua; + let mut bytes = Vec::new(); + unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 1)?; + let bytes_ptr = &mut bytes as *mut _; + protect_lua_closure(lua.state, 0, 0, |state| { + lua.push_ref(&self.0); + let dump_result = dump(state, Some(writer), bytes_ptr as *mut c_void, 0); + // It can only return an error from our writer. + debug_assert_eq!(dump_result, 0); + })?; + } + Ok(bytes) + } } diff --git a/src/lua.rs b/src/lua.rs index 086e925e..0077bc3d 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -19,8 +19,8 @@ use crate::hook::{hook_proc, Debug, HookTriggers}; use crate::markers::NoRefUnwindSafe; use crate::types::Callback; use crate::util::{ - assert_stack, init_error_registry, protect_lua_closure, push_globaltable, requiref, safe_pcall, - safe_xpcall, userdata_destructor, + assert_stack, dostring, init_error_registry, protect_lua_closure, push_globaltable, rawlen, + requiref, safe_pcall, safe_xpcall, userdata_destructor, }; bitflags! { @@ -64,8 +64,12 @@ bitflags! { /// Flags describing the set of lua modules to load. pub struct InitFlags: u32 { const PCALL_WRAPPERS = 0x1; + const LOAD_WRAPPERS = 0x2; + const REMOVE_LOADLIB = 0x4; - const DEFAULT = InitFlags::PCALL_WRAPPERS.bits; + const DEFAULT = InitFlags::PCALL_WRAPPERS.bits | + InitFlags::LOAD_WRAPPERS.bits | + InitFlags::REMOVE_LOADLIB.bits; const NONE = 0; } } @@ -568,7 +572,6 @@ unsafe fn create_lua(lua_mod_to_load: StdLib, init_flags: InitFlags) -> Lua { ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX); // Override pcall and xpcall with versions that cannot be used to catch rust panics. - if init_flags.contains(InitFlags::PCALL_WRAPPERS) { push_globaltable(state); @@ -583,6 +586,160 @@ unsafe fn create_lua(lua_mod_to_load: StdLib, init_flags: InitFlags) -> Lua { ffi::lua_pop(state, 1); } + // Override dofile, load, and loadfile with versions that won't load + // binary files. + if init_flags.contains(InitFlags::LOAD_WRAPPERS) { + // These are easier to override in Lua. + #[cfg(any(rlua_lua53, rlua_lua54))] + let wrapload = r#" + do + -- load(chunk [, chunkname [, mode [, env]]]) + local real_load = load + load = function(...) + local args = table.pack(...) + args[3] = "t" + if args.n < 3 then args.n = 3 end + return real_load(table.unpack(args)) + end + + -- loadfile ([filename [, mode [, env]]]) + local real_loadfile = loadfile + local real_error = error + loadfile = function(...) + local args = table.pack(...) + args[2] = "t" + if args.n < 2 then args.n = 2 end + return real_loadfile(table.unpack(args)) + end + + -- dofile([filename]) + local real_dofile = dofile + dofile = function(filename) + -- Note: this is the wrapped loadfile above + local chunk = loadfile(filename) + if chunk then + return chunk() + else + real_error("rlua dofile: attempt to load bytecode") + end + end + end + "#; + #[cfg(rlua_lua51)] + let wrapload = r#" + do + -- load(chunk [, chunkname]) + local real_load = load + -- save type() in case user code replaces it + local real_type = type + local real_error = error + load = function(func, chunkname) + local first_chunk = true + local wrap_func = function() + if not first_chunk then + return func() + else + local data = func() + if data == nil then return nil end + assert(real_type(data) == "string") + if data:len() > 0 then + if data:byte(1) == 27 then + real_error("rlua load: loading binary chunks is not allowed") + end + first_chunk = false + end + return data + end + end + return real_load(wrap_func, chunkname) + end + + -- loadstring(string [, chunkname]) + local real_loadstring = loadstring + loadstring = function(s, chunkname) + if type(s) ~= "string" then + real_error("rlua loadstring: string expected.") + elseif s:byte(1) == 27 then + -- This is a binary chunk, so disallow + return nil, "rlua loadstring: loading binary chunks is not allowed" + else + return real_loadstring(s, chunkname) + end + end + + -- loadfile ([filename]) + local real_loadfile = loadfile + local real_io_open = io.open + loadfile = function(filename) + local f, err = real_io_open(filename, "rb") + if not f then + return nil, err + end + local first_chunk = true + local func = function() + return f:read(4096) + end + -- Note: the safe load from above. + return load(func, filename) + end + + -- dofile([filename]) + local real_dofile = dofile + dofile = function(filename) + -- Note: this is the wrapped loadfile above + local chunk = loadfile(filename) + if chunk then + return chunk() + else + real_error("rlua dofile: attempt to load bytecode") + end + end + end + "#; + + let result = dostring(state, wrapload); + if result != 0 { + use std::ffi::CStr; + let errmsg = ffi::lua_tostring(state, -1); + eprintln!( + "Internal error running setup code: {:?}", + CStr::from_ptr(errmsg) + ); + } + assert_eq!(result, 0); + } + + if init_flags.contains(InitFlags::REMOVE_LOADLIB) { + ffi::lua_getglobal(state, cstr!("package")); + let t = ffi::lua_type(state, -1); + if t == ffi::LUA_TTABLE { + // Package is loaded. Remove loadlib. + ffi::lua_pushnil(state); + ffi::lua_setfield(state, -2, cstr!("loadlib")); + + #[cfg(rlua_lua51)] + let searchers_name = cstr!("loaders"); + #[cfg(any(rlua_lua53, rlua_lua54))] + let searchers_name = cstr!("searchers"); + + ffi::lua_getfield(state, -1, searchers_name); + debug_assert_eq!(ffi::lua_type(state, -1), ffi::LUA_TTABLE); + debug_assert_eq!(rawlen(state, -1), 4); + // Remove the searchers/loaders which will load C libraries. + ffi::lua_pushnil(state); + ffi::lua_rawseti(state, -2, 4); + ffi::lua_pushnil(state); + ffi::lua_rawseti(state, -2, 3); + + ffi::lua_pop(state, 1); + } else { + // Assume it's not present otherwise. + assert_eq!(t, ffi::LUA_TNIL); + } + // Pop the package (or nil) off the stack. + ffi::lua_pop(state, 1); + } + // Create ref stack thread and place it in the registry to prevent it from being garbage // collected. diff --git a/src/userdata.rs b/src/userdata.rs index ac29c684..980ab769 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -342,7 +342,7 @@ impl<'lua> AnyUserData<'lua> { /// Sets an associated value to this `AnyUserData`. /// - /// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`]. + /// The value may be any Lua value whatsoever, and can be retrieved with [`AnyUserData::get_user_value`]. /// /// Equivalent to set_i_user_value(v, 1) /// diff --git a/src/util.rs b/src/util.rs index b425ee1b..9a8deb9a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -96,7 +96,7 @@ pub unsafe fn protect_lua( // given function return type is not the return value count, instead the inner function return // values are assumed to match the `nresults` param. Internally uses 3 extra stack spaces, and does // not call checkstack. Provided function must *not* panic, and since it will generally be -// lonjmping, should not contain any values that implement Drop. +// longjmping, should not contain any values that implement Drop. pub unsafe fn protect_lua_closure( state: *mut ffi::lua_State, nargs: c_int, @@ -696,6 +696,21 @@ pub unsafe fn loadbufferx( ffi::luaL_loadbuffer(state, buf, size, name) } +pub unsafe fn dostring(state: *mut ffi::lua_State, s: &str) -> c_int { + let load_result = loadbufferx( + state, + s.as_ptr() as *const c_char, + s.len(), + cstr!(""), + cstr!("t"), + ); + if load_result == ffi::LUA_OK { + ffi::lua_pcall(state, 0, ffi::LUA_MULTRET, 0) + } else { + load_result + } +} + #[cfg(any(rlua_lua53, rlua_lua54))] // Like luaL_requiref but doesn't leave the module on the stack. pub unsafe fn requiref( @@ -771,6 +786,19 @@ pub unsafe fn traceback( ffi::lua_pushstring(push_state, msg); } +#[cfg(any(rlua_lua53, rlua_lua54))] +pub use ffi::lua_dump as dump; + +#[cfg(rlua_lua51)] +pub unsafe fn dump( + state: *mut ffi::lua_State, + writer: ffi::lua_Writer, + data: *mut c_void, + _strip: c_int, +) -> c_int { + ffi::lua_dump(state, writer, data) +} + // In the context of a lua callback, this will call the given function and if the given function // returns an error, *or if the given function panics*, this will result in a call to lua_error (a // longjmp). The error or panic is wrapped in such a way that when calling pop_error back on diff --git a/tests/tests.rs b/tests/tests.rs index f8ebb0cc..264f68ac 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,11 +1,12 @@ +use bstr::BString; use std::iter::FromIterator; use std::panic::catch_unwind; use std::sync::Arc; use std::{error, f32, f64, fmt}; use rlua::{ - Error, ExternalError, Function, /* InitFlags, */ Lua, Nil, Result, StdLib, String, Table, - UserData, Value, Variadic, + Error, ExternalError, Function, InitFlags, Lua, Nil, Result, StdLib, String, Table, UserData, + Value, Variadic, }; #[test] @@ -406,6 +407,505 @@ fn test_error_nopcall_wrap() { } */ +#[test] +fn test_load_wrappers() { + Lua::new().context(|lua| { + let globals = lua.globals(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + assert(type(binchunk) == "string") + local binchunk_copy = binchunk + chunk = load(function () + local result = binchunk_copy + binchunk_copy = nil + return result + end, "bad", "bt") + print(chunk) + assert(chunk == nil) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + local s = "x = x + 4" + local function loader() + local result = s + s = nil + return result + end + chunk = load(loader) + assert(chunk ~= nil) + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }); +} + +#[test] +fn test_no_load_wrappers() { + unsafe { + Lua::unsafe_new_with_flags( + StdLib::ALL_NO_DEBUG, + InitFlags::DEFAULT - InitFlags::LOAD_WRAPPERS, + ) + .context(|lua| { + let globals = lua.globals(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + assert(type(binchunk) == "string") + assert(binchunk:byte(1) == 27) + local stringsource = binchunk + local loader = function () + local result = stringsource + stringsource = nil + return result + end + if _VERSION ~= "Lua 5.1" then + -- Lua 5.1 doesn't support the mode parameter. + chunk = load(binchunk, "fail", "t") + assert(chunk == nil) + end + chunk = load(loader, "good", "bt") + assert(chunk ~= nil) + chunk() + stringsource = "x = x + 3" + chunk = load(loader) + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }) + }; +} + +#[test] +fn test_loadfile_wrappers() { + let mut tmppath = std::env::temp_dir(); + tmppath.push("test_loadfile_wrappers.lua"); + + Lua::new().context(|lua| { + let globals = lua.globals(); + globals.set("filename", tmppath.to_str().unwrap()).unwrap(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + let binchunk = globals.get::<_, BString>("binchunk").unwrap(); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, binchunk).unwrap(); + lua.load( + r#" + chunk = loadfile(filename) + assert(chunk == nil) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, "x = x + 4").unwrap(); + lua.load( + r#" + chunk = loadfile(filename) + assert(chunk ~= nil) + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }); +} + +#[test] +fn test_no_loadfile_wrappers() { + let mut tmppath = std::env::temp_dir(); + let mut tmppath2 = tmppath.clone(); + tmppath.push("test_no_loadfile_wrappers.lua"); + tmppath2.push("test_no_loadfile_wrappers2.lua"); + + unsafe { + Lua::unsafe_new_with_flags( + StdLib::ALL_NO_DEBUG, + InitFlags::DEFAULT - InitFlags::LOAD_WRAPPERS, + ) + .context(|lua| { + let globals = lua.globals(); + globals.set("filename", tmppath.to_str().unwrap()).unwrap(); + globals + .set("filename2", tmppath2.to_str().unwrap()) + .unwrap(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + let binchunk = globals.get::<_, BString>("binchunk").unwrap(); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, binchunk).unwrap(); + std::fs::write(&tmppath2, "x = x + 3").unwrap(); + lua.load( + r#" + if _VERSION ~= "Lua 5.1" then + -- Lua 5.1 doesn't have the mode argument, so is + -- effectively always "bt". + chunk = loadfile(filename, "t") + assert(chunk == nil) + end + chunk = loadfile(filename, "bt") + assert(chunk ~= nil) + chunk() + chunk = loadfile(filename2) + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }) + }; +} + +#[test] +fn test_dofile_wrappers() { + let mut tmppath = std::env::temp_dir(); + tmppath.push("test_dofile_wrappers.lua"); + + Lua::new().context(|lua| { + let globals = lua.globals(); + globals.set("filename", tmppath.to_str().unwrap()).unwrap(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + let binchunk = globals.get::<_, BString>("binchunk").unwrap(); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, binchunk).unwrap(); + lua.load( + r#" + ok, err = pcall(dofile, filename) + assert(not ok) + assert(err:match("rlua dofile: attempt to load bytecode")) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, "x = x + 4").unwrap(); + lua.load( + r#" + ok, ret = pcall(dofile, filename) + assert(ok) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }); +} + +#[test] +fn test_no_dofile_wrappers() { + let mut tmppath = std::env::temp_dir(); + tmppath.push("test_no_dofile_wrappers.lua"); + + unsafe { + Lua::unsafe_new_with_flags( + StdLib::ALL_NO_DEBUG, + InitFlags::DEFAULT - InitFlags::LOAD_WRAPPERS, + ) + .context(|lua| { + let globals = lua.globals(); + globals.set("filename", tmppath.to_str().unwrap()).unwrap(); + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + let binchunk = globals.get::<_, BString>("binchunk").unwrap(); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + std::fs::write(&tmppath, binchunk).unwrap(); + lua.load( + r#" + ok, ret = pcall(dofile, filename) + assert(ok) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 2); + std::fs::write(&tmppath, "x = x + 4").unwrap(); + lua.load( + r#" + ok, ret = pcall(dofile, filename) + assert(ok) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 6); + }); + } +} + +#[test] +fn test_loadstring_wrappers() { + Lua::new().context(|lua| { + let globals = lua.globals(); + if globals.get::<_, Function>("loadstring").is_err() { + // Loadstring is not present in Lua 5.4, and only with a + // compatibility mode in Lua 5.3. + return; + } + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + assert(type(binchunk) == "string") + chunk = loadstring(binchunk) + assert(chunk == nil) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + local s = "x = x + 4" + chunk = loadstring(s) + assert(chunk ~= nil) + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }); +} + +#[test] +fn test_no_loadstring_wrappers() { + unsafe { + Lua::unsafe_new_with_flags( + StdLib::ALL_NO_DEBUG, + InitFlags::DEFAULT - InitFlags::LOAD_WRAPPERS, + ) + .context(|lua| { + let globals = lua.globals(); + if globals.get::<_, Function>("loadstring").is_err() { + // Loadstring is not present in Lua 5.4, and only with a + // compatibility mode in Lua 5.3. + return; + } + lua.load( + r#" + x = 0 + function incx() + x = x + 1 + end + binchunk = string.dump(incx) + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 0); + + lua.load( + r#" + incx() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 1); + lua.load( + r#" + assert(type(binchunk) == "string") + assert(binchunk:byte(1) == 27) + chunk = loadstring(binchunk) + assert(chunk ~= nil) + chunk() + chunk = loadstring("x = x + 3") + chunk() + "#, + ) + .exec() + .unwrap(); + assert_eq!(globals.get::<_, u32>("x").unwrap(), 5); + }) + }; +} + +#[test] +fn test_default_loadlib() { + Lua::new().context(|lua| { + let globals = lua.globals(); + let package = globals.get::<_, Table>("package").unwrap(); + let loadlib = package.get::<_, Function>("loadlib"); + assert!(loadlib.is_err()); + + lua.load( + r#" + assert(#(package.loaders or package.searchers) == 2) + "#, + ) + .exec() + .unwrap(); + }); +} + +#[test] +fn test_no_remove_loadlib() { + unsafe { + Lua::unsafe_new_with_flags( + StdLib::ALL_NO_DEBUG, + InitFlags::DEFAULT - InitFlags::REMOVE_LOADLIB, + ) + .context(|lua| { + let globals = lua.globals(); + let package = globals.get::<_, Table>("package").unwrap(); + let _loadlib = package.get::<_, Function>("loadlib").unwrap(); + + lua.load( + r#" + assert(#(package.loaders or package.searchers) == 4) + "#, + ) + .exec() + .unwrap(); + }); + } +} + #[test] fn test_result_conversions() { Lua::new().context(|lua| {