Skip to content

Commit

Permalink
feat: Implement global namespace registration (#202)
Browse files Browse the repository at this point in the history
* feat: add more functionality to the script registry

* add global namespace alias

* register global functions in lua
  • Loading branch information
makspll authored Jan 13, 2025
1 parent 7921e19 commit c1b6375
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@ pub enum Namespace {
OnType(TypeId),
}

/// A type which implements [`IntoNamespace`] by always converting to the global namespace
pub struct GlobalNamespace;

pub trait IntoNamespace {
fn into_namespace() -> Namespace;
}

impl<T: ?Sized + 'static> IntoNamespace for T {
fn into_namespace() -> Namespace {
Namespace::OnType(TypeId::of::<T>())
if TypeId::of::<T>() == TypeId::of::<GlobalNamespace>() {
Namespace::Global
} else {
Namespace::OnType(TypeId::of::<T>())
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ impl ScriptFunctionRegistryArc {

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FunctionKey {
name: Cow<'static, str>,
namespace: Namespace,
pub name: Cow<'static, str>,
pub namespace: Namespace,
}

#[derive(Debug, Default)]
Expand All @@ -372,6 +372,8 @@ pub struct ScriptFunctionRegistry {
impl ScriptFunctionRegistry {
/// Register a script function with the given name. If the name already exists,
/// the new function will be registered as an overload of the function.
///
/// If you want to overwrite an existing function, use [`ScriptFunctionRegistry::overwrite`]
pub fn register<F, M>(
&mut self,
namespace: Namespace,
Expand All @@ -380,21 +382,44 @@ impl ScriptFunctionRegistry {
) where
F: ScriptFunction<'static, M>,
{
self.register_overload(namespace, name, func);
self.register_overload(namespace, name, func, false);
}

/// Overwrite a function with the given name. If the function does not exist, it will be registered as a new function.
pub fn overwrite<F, M>(
&mut self,
namespace: Namespace,
name: impl Into<Cow<'static, str>>,
func: F,
) where
F: ScriptFunction<'static, M>,
{
self.register_overload(namespace, name, func, true);
}

/// Remove a function from the registry if it exists. Returns the removed function if it was found.
pub fn remove(
&mut self,
namespace: Namespace,
name: impl Into<Cow<'static, str>>,
) -> Option<DynamicScriptFunction> {
let name = name.into();
self.functions.remove(&FunctionKey { name, namespace })
}

fn register_overload<F, M>(
&mut self,
namespace: Namespace,
name: impl Into<Cow<'static, str>>,
func: F,
overwrite: bool,
) where
F: ScriptFunction<'static, M>,
{
// always start with non-suffixed registration
// TODO: we do alot of string work, can we make this all more efficient?
let name: Cow<'static, str> = name.into();
if !self.contains(namespace, name.clone()) {
if overwrite || !self.contains(namespace, name.clone()) {
let func = func
.into_dynamic_script_function()
.with_name(name.clone())
Expand Down Expand Up @@ -636,4 +661,34 @@ mod test {
assert_eq!(all_functions[0].info.name(), "test");
assert_eq!(all_functions[1].info.name(), "test-1");
}

#[test]
fn test_overwrite_script_function() {
let mut registry = ScriptFunctionRegistry::default();
let fn_ = |a: usize, b: usize| a + b;
let namespace = Namespace::Global;
registry.register(namespace, "test", fn_);
let fn_2 = |a: usize, b: i32| a + (b as usize);
registry.overwrite(namespace, "test", fn_2);

let all_functions = registry
.iter_overloads(namespace, "test")
.expect("Failed to get overloads")
.collect::<Vec<_>>();

assert_eq!(all_functions.len(), 1);
assert_eq!(all_functions[0].info.name(), "test");
}

#[test]
fn test_remove_script_function() {
let mut registry = ScriptFunctionRegistry::default();
let fn_ = |a: usize, b: usize| a + b;
let namespace = Namespace::Global;
registry.register(namespace, "test", fn_);
let removed = registry.remove(namespace, "test");
assert!(removed.is_some());
let removed = registry.remove(namespace, "test");
assert!(removed.is_none());
}
}
5 changes: 4 additions & 1 deletion crates/bevy_mod_scripting_functions/src/test_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use bevy::{
use bevy_mod_scripting_core::{
bindings::{
function::{
namespace::NamespaceBuilder,
namespace::{GlobalNamespace, NamespaceBuilder},
script_function::{CallerContext, DynamicScriptFunctionMut},
},
pretty_print::DisplayWithWorld,
Expand Down Expand Up @@ -79,4 +79,7 @@ pub fn register_test_functions(world: &mut App) {
}
},
);

NamespaceBuilder::<GlobalNamespace>::new_unregistered(world)
.register("global_hello_world", || Ok("hi!"));
}
100 changes: 60 additions & 40 deletions crates/languages/bevy_mod_scripting_lua/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use bevy::{
app::{App, Plugin},
app::Plugin,
ecs::{entity::Entity, world::World},
};
use bevy_mod_scripting_core::{
asset::{AssetPathToLanguageMapper, Language},
bindings::{
script_value::ScriptValue, ThreadWorldContainer, WorldCallbackAccess, WorldContainer,
function::namespace::Namespace, script_value::ScriptValue, ThreadWorldContainer,
WorldCallbackAccess, WorldContainer,
},
context::{ContextBuilder, ContextInitializer, ContextPreHandlingInitializer},
error::ScriptError,
event::CallbackLabel,
reflection_extensions::PartialReflectExt,
script::ScriptId,
AddContextInitializer, IntoScriptPluginParams, ScriptingPlugin,
IntoScriptPluginParams, ScriptingPlugin,
};
use bindings::{
reference::{LuaReflectReference, LuaStaticReflectReference},
Expand Down Expand Up @@ -48,16 +49,62 @@ impl Default for LuaScriptingPlugin {
language_mapper: Some(AssetPathToLanguageMapper {
map: lua_language_mapper,
}),
context_initializers: vec![|_script_id, context| {
context
.globals()
.set(
"world",
LuaStaticReflectReference(std::any::TypeId::of::<World>()),
)
.map_err(ScriptError::from_mlua_error)?;
Ok(())
}],
context_initializers: vec![
|_script_id, context| {
// set the world global
context
.globals()
.set(
"world",
LuaStaticReflectReference(std::any::TypeId::of::<World>()),
)
.map_err(ScriptError::from_mlua_error)?;
Ok(())
},
|_script_id, context: &mut Lua| {
// set static globals
let world = ThreadWorldContainer.get_world();
let type_registry = world.type_registry();
let type_registry = type_registry.read();

for registration in type_registry.iter() {
// only do this for non generic types
// we don't want to see `Vec<Entity>:function()` in lua
if !registration.type_info().generics().is_empty() {
continue;
}

if let Some(global_name) =
registration.type_info().type_path_table().ident()
{
let ref_ = LuaStaticReflectReference(registration.type_id());
context
.globals()
.set(global_name, ref_)
.map_err(ScriptError::from_mlua_error)?;
}
}

// go through functions in the global namespace and add them to the lua context
let script_function_registry = world.script_function_registry();
let script_function_registry = script_function_registry.read();

for (key, function) in script_function_registry
.iter_all()
.filter(|(k, _)| k.namespace == Namespace::Global)
{
context
.globals()
.set(
key.name.to_string(),
LuaScriptValue::from(ScriptValue::Function(function.clone())),
)
.map_err(ScriptError::from_mlua_error)?;
}

Ok(())
},
],
context_pre_handling_initializers: vec![|script_id, entity, context| {
let world = ThreadWorldContainer.get_world();
context
Expand Down Expand Up @@ -89,33 +136,6 @@ impl Plugin for LuaScriptingPlugin {
fn build(&self, app: &mut bevy::prelude::App) {
self.scripting_plugin.build(app);
}

fn cleanup(&self, app: &mut App) {
// find all registered types, and insert dummy for calls

app.add_context_initializer::<LuaScriptingPlugin>(|_script_id, context: &mut Lua| {
let world = ThreadWorldContainer.get_world();
let type_registry = world.type_registry();
let type_registry = type_registry.read();

for registration in type_registry.iter() {
// only do this for non generic types
// we don't want to see `Vec<Entity>:function()` in lua
if !registration.type_info().generics().is_empty() {
continue;
}

if let Some(global_name) = registration.type_info().type_path_table().ident() {
let ref_ = LuaStaticReflectReference(registration.type_id());
context
.globals()
.set(global_name, ref_)
.map_err(ScriptError::from_mlua_error)?;
}
}
Ok(())
});
}
}

pub fn lua_context_load(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
assert(global_hello_world() == "hi!", "global_hello_world() == 'hi!'")

0 comments on commit c1b6375

Please sign in to comment.