Skip to content

Commit

Permalink
bevy_derive: Add derives for Deref and DerefMut (bevyengine#4328)
Browse files Browse the repository at this point in the history
# Objective

A common pattern in Rust is the [newtype](https://doc.rust-lang.org/rust-by-example/generics/new_types.html). This is an especially useful pattern in Bevy as it allows us to give common/foreign types different semantics (such as allowing it to implement `Component` or `FromWorld`) or to simply treat them as a "new type" (clever). For example, it allows us to wrap a common `Vec<String>` and do things like:

```rust
#[derive(Component)]
struct Items(Vec<String>);

fn give_sword(query: Query<&mut Items>) { 
  query.single_mut().0.push(String::from("Flaming Poisoning Raging Sword of Doom"));
}
```

> We could then define another struct that wraps `Vec<String>` without anything clashing in the query.

However, one of the worst parts of this pattern is the ugly `.0` we have to write in order to access the type we actually care about. This is why people often implement `Deref` and `DerefMut` in order to get around this.

Since it's such a common pattern, especially for Bevy, it makes sense to add a derive macro to automatically add those implementations.


## Solution

Added a derive macro for `Deref` and another for `DerefMut` (both exported into the prelude). This works on all structs (including tuple structs) as long as they only contain a single field:

```rust
#[derive(Deref)]
struct Foo(String);

#[derive(Deref, DerefMut)]
struct Bar {
  name: String,
}
```

This allows us to then remove that pesky `.0`:

```rust
#[derive(Component, Deref, DerefMut)]
struct Items(Vec<String>);

fn give_sword(query: Query<&mut Items>) { 
  query.single_mut().push(String::from("Flaming Poisoning Raging Sword of Doom"));
}
```

### Alternatives

There are other alternatives to this such as by using the [`derive_more`](https://crates.io/crates/derive_more) crate. However, it doesn't seem like we need an entire crate just yet since we only need `Deref` and `DerefMut` (for now).

### Considerations

One thing to consider is that the Rust std library recommends _not_ using `Deref` and `DerefMut` for things like this: "`Deref` should only be implemented for smart pointers to avoid confusion" ([reference](https://doc.rust-lang.org/std/ops/trait.Deref.html)). Personally, I believe it makes sense to use it in the way described above, but others may disagree.

### Additional Context

Discord: https://discord.com/channels/691052431525675048/692572690833473578/956648422163746827 (controversiality discussed [here](https://discord.com/channels/691052431525675048/692572690833473578/956711911481835630))

---

## Changelog

- Add `Deref` derive macro (exported to prelude)
- Add `DerefMut` derive macro (exported to prelude)
- Updated most newtypes in examples to use one or both derives

Co-authored-by: MrGVSV <49806985+MrGVSV@users.noreply.github.com>
  • Loading branch information
2 people authored and ItsDoot committed Feb 1, 2023
1 parent d0b738d commit ff71daa
Show file tree
Hide file tree
Showing 18 changed files with 174 additions and 37 deletions.
69 changes: 69 additions & 0 deletions crates/bevy_derive/src/derefs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use proc_macro::{Span, TokenStream};
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Index, Member, Type};

pub fn derive_deref(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

let ident = &ast.ident;
let (field_member, field_type) = match get_inner_field(&ast, false) {
Ok(items) => items,
Err(err) => {
return err.into_compile_error().into();
}
};
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

TokenStream::from(quote! {
impl #impl_generics ::std::ops::Deref for #ident #ty_generics #where_clause {
type Target = #field_type;

fn deref(&self) -> &Self::Target {
&self.#field_member
}
}
})
}

pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

let ident = &ast.ident;
let (field_member, _) = match get_inner_field(&ast, true) {
Ok(items) => items,
Err(err) => {
return err.into_compile_error().into();
}
};
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

TokenStream::from(quote! {
impl #impl_generics ::std::ops::DerefMut for #ident #ty_generics #where_clause {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.#field_member
}
}
})
}

fn get_inner_field(ast: &DeriveInput, is_mut: bool) -> syn::Result<(Member, &Type)> {
match &ast.data {
Data::Struct(data_struct) if data_struct.fields.len() == 1 => {
let field = data_struct.fields.iter().next().unwrap();
let member = field
.ident
.as_ref()
.map(|name| Member::Named(name.clone()))
.unwrap_or_else(|| Member::Unnamed(Index::from(0)));
Ok((member, &field.ty))
}
_ => {
let msg = if is_mut {
"DerefMut can only be derived for structs with a single field"
} else {
"Deref can only be derived for structs with a single field"
};
Err(syn::Error::new(Span::call_site().into(), msg))
}
}
}
56 changes: 56 additions & 0 deletions crates/bevy_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ extern crate proc_macro;

mod app_plugin;
mod bevy_main;
mod derefs;
mod enum_variant_meta;
mod modules;

Expand All @@ -15,6 +16,61 @@ pub fn derive_dynamic_plugin(input: TokenStream) -> TokenStream {
app_plugin::derive_dynamic_plugin(input)
}

/// Implements [`Deref`] for _single-item_ structs. This is especially useful when
/// utilizing the [newtype] pattern.
///
/// If you need [`DerefMut`] as well, consider using the other [derive] macro alongside
/// this one.
///
/// # Example
///
/// ```
/// use bevy_derive::Deref;
///
/// #[derive(Deref)]
/// struct MyNewtype(String);
///
/// let foo = MyNewtype(String::from("Hello"));
/// assert_eq!(5, foo.len());
/// ```
///
/// [`Deref`]: std::ops::Deref
/// [newtype]: https://doc.rust-lang.org/rust-by-example/generics/new_types.html
/// [`DerefMut`]: std::ops::DerefMut
/// [derive]: crate::derive_deref_mut
#[proc_macro_derive(Deref)]
pub fn derive_deref(input: TokenStream) -> TokenStream {
derefs::derive_deref(input)
}

/// Implements [`DerefMut`] for _single-item_ structs. This is especially useful when
/// utilizing the [newtype] pattern.
///
/// [`DerefMut`] requires a [`Deref`] implementation. You can implement it manually or use
/// Bevy's [derive] macro for convenience.
///
/// # Example
///
/// ```
/// use bevy_derive::{Deref, DerefMut};
///
/// #[derive(Deref, DerefMut)]
/// struct MyNewtype(String);
///
/// let mut foo = MyNewtype(String::from("Hello"));
/// foo.push_str(" World!");
/// assert_eq!("Hello World!", *foo);
/// ```
///
/// [`DerefMut`]: std::ops::DerefMut
/// [newtype]: https://doc.rust-lang.org/rust-by-example/generics/new_types.html
/// [`Deref`]: std::ops::Deref
/// [derive]: crate::derive_deref
#[proc_macro_derive(DerefMut)]
pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
derefs::derive_deref_mut(input)
}

#[proc_macro_attribute]
pub fn bevy_main(attr: TokenStream, item: TokenStream) -> TokenStream {
bevy_main::bevy_main(attr, item)
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_internal/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub use crate::{
transform::prelude::*, utils::prelude::*, window::prelude::*, DefaultPlugins, MinimalPlugins,
};

pub use bevy_derive::bevy_main;
pub use bevy_derive::{bevy_main, Deref, DerefMut};

#[doc(hidden)]
#[cfg(feature = "bevy_audio")]
Expand Down
3 changes: 2 additions & 1 deletion examples/2d/contributors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct ContributorSelection {
idx: usize,
}

#[derive(Deref, DerefMut)]
struct SelectTimer(Timer);

#[derive(Component)]
Expand Down Expand Up @@ -161,7 +162,7 @@ fn select_system(
mut query: Query<(&Contributor, &mut Sprite, &mut Transform)>,
time: Res<Time>,
) {
if !timer.0.tick(time.delta()).just_finished() {
if !timer.tick(time.delta()).just_finished() {
return;
}

Expand Down
5 changes: 3 additions & 2 deletions examples/2d/many_sprites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn move_camera(time: Res<Time>, mut camera_query: Query<&mut Transform, With<Cam
* Transform::from_translation(Vec3::X * CAMERA_SPEED * time.delta_seconds());
}

#[derive(Deref, DerefMut)]
struct PrintingTimer(Timer);

impl Default for PrintingTimer {
Expand All @@ -84,9 +85,9 @@ impl Default for PrintingTimer {

// System for printing the number of sprites on every tick of the timer
fn print_sprite_count(time: Res<Time>, mut timer: Local<PrintingTimer>, sprites: Query<&Sprite>) {
timer.0.tick(time.delta());
timer.tick(time.delta());

if timer.0.just_finished() {
if timer.just_finished() {
info!("Sprites: {}", sprites.iter().count(),);
}
}
6 changes: 3 additions & 3 deletions examples/2d/sprite_sheet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn main() {
.run();
}

#[derive(Component)]
#[derive(Component, Deref, DerefMut)]
struct AnimationTimer(Timer);

fn animate_sprite(
Expand All @@ -21,8 +21,8 @@ fn animate_sprite(
)>,
) {
for (mut timer, mut sprite, texture_atlas_handle) in query.iter_mut() {
timer.0.tick(time.delta());
if timer.0.just_finished() {
timer.tick(time.delta());
if timer.just_finished() {
let texture_atlas = texture_atlases.get(texture_atlas_handle).unwrap();
sprite.index = (sprite.index + 1) % texture_atlas.textures.len();
}
Expand Down
5 changes: 3 additions & 2 deletions examples/3d/many_cubes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ fn print_mesh_count(
mut timer: Local<PrintingTimer>,
sprites: Query<(&Handle<Mesh>, &ComputedVisibility)>,
) {
timer.0.tick(time.delta());
timer.tick(time.delta());

if timer.0.just_finished() {
if timer.just_finished() {
info!(
"Meshes: {} - Visible Meshes {}",
sprites.iter().len(),
Expand All @@ -156,6 +156,7 @@ fn print_mesh_count(
}
}

#[derive(Deref, DerefMut)]
struct PrintingTimer(Timer);

impl Default for PrintingTimer {
Expand Down
6 changes: 4 additions & 2 deletions examples/async_tasks/async_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ fn main() {
// Number of cubes to spawn across the x, y, and z axis
const NUM_CUBES: u32 = 6;

#[derive(Deref)]
struct BoxMeshHandle(Handle<Mesh>);
#[derive(Deref)]
struct BoxMaterialHandle(Handle<StandardMaterial>);

/// Startup system which runs only once and generates our Box Mesh
Expand Down Expand Up @@ -84,8 +86,8 @@ fn handle_tasks(
if let Some(transform) = future::block_on(future::poll_once(&mut *task)) {
// Add our new PbrBundle of components to our tagged entity
commands.entity(entity).insert_bundle(PbrBundle {
mesh: box_mesh_handle.0.clone(),
material: box_material_handle.0.clone(),
mesh: box_mesh_handle.clone(),
material: box_material_handle.clone(),
transform,
..default()
});
Expand Down
6 changes: 4 additions & 2 deletions examples/async_tasks/external_source_external_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ fn main() {
.run();
}

#[derive(Deref)]
struct StreamReceiver(Receiver<u32>);
struct StreamEvent(u32);

#[derive(Deref)]
struct LoadedFont(Handle<Font>);

fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
Expand All @@ -43,7 +45,7 @@ fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {

// This system reads from the receiver and sends events to Bevy
fn read_stream(receiver: ResMut<StreamReceiver>, mut events: EventWriter<StreamEvent>) {
for from_stream in receiver.0.try_iter() {
for from_stream in receiver.try_iter() {
events.send(StreamEvent(from_stream));
}
}
Expand All @@ -54,7 +56,7 @@ fn spawn_text(
loaded_font: Res<LoadedFont>,
) {
let text_style = TextStyle {
font: loaded_font.0.clone(),
font: loaded_font.clone(),
font_size: 20.0,
color: Color::WHITE,
};
Expand Down
4 changes: 2 additions & 2 deletions examples/ecs/generic_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ enum AppState {
#[derive(Component)]
struct TextToPrint(String);

#[derive(Component)]
#[derive(Component, Deref, DerefMut)]
struct PrinterTick(bevy::prelude::Timer);

#[derive(Component)]
Expand Down Expand Up @@ -67,7 +67,7 @@ fn setup_system(mut commands: Commands) {

fn print_text_system(time: Res<Time>, mut query: Query<(&mut PrinterTick, &TextToPrint)>) {
for (mut timer, text) in query.iter_mut() {
if timer.0.tick(time.delta()).just_finished() {
if timer.tick(time.delta()).just_finished() {
info!("{}", text.0);
}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/ecs/parallel_query.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bevy::{prelude::*, tasks::prelude::*};
use rand::random;

#[derive(Component)]
#[derive(Component, Deref)]
struct Velocity(Vec2);

fn spawn_system(mut commands: Commands, asset_server: Res<AssetServer>) {
Expand Down Expand Up @@ -31,7 +31,7 @@ fn move_system(pool: Res<ComputeTaskPool>, mut sprites: Query<(&mut Transform, &
// See the ParallelIterator documentation for more information on when
// to use or not use ParallelIterator over a normal Iterator.
sprites.par_for_each_mut(&pool, 32, |(mut transform, velocity)| {
transform.translation += velocity.0.extend(0.0);
transform.translation += velocity.extend(0.0);
});
}

Expand Down
3 changes: 2 additions & 1 deletion examples/ecs/system_chaining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ fn main() {
.run();
}

#[derive(Deref)]
struct Message(String);

// this system produces a Result<usize> output by trying to parse the Message resource
fn parse_message_system(message: Res<Message>) -> Result<usize> {
Ok(message.0.parse::<usize>()?)
Ok(message.parse::<usize>()?)
}

// This system takes a Result<usize> input and either prints the parsed value or the error message
Expand Down
4 changes: 2 additions & 2 deletions examples/ecs/timers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn main() {
.run();
}

#[derive(Component)]
#[derive(Component, Deref, DerefMut)]
pub struct PrintOnCompletionTimer(Timer);

pub struct Countdown {
Expand Down Expand Up @@ -44,7 +44,7 @@ fn setup(mut commands: Commands) {
/// using bevy's `Time` resource to get the delta between each update.
fn print_when_completed(time: Res<Time>, mut query: Query<&mut PrintOnCompletionTimer>) {
for mut timer in query.iter_mut() {
if timer.0.tick(time.delta()).just_finished() {
if timer.tick(time.delta()).just_finished() {
info!("Entity timer just finished");
}
}
Expand Down
18 changes: 9 additions & 9 deletions examples/game/breakout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct Paddle;
#[derive(Component)]
struct Ball;

#[derive(Component)]
#[derive(Component, Deref, DerefMut)]
struct Velocity(Vec2);

#[derive(Component)]
Expand Down Expand Up @@ -334,8 +334,8 @@ fn move_paddle(

fn apply_velocity(mut query: Query<(&mut Transform, &Velocity)>) {
for (mut transform, velocity) in query.iter_mut() {
transform.translation.x += velocity.0.x * TIME_STEP;
transform.translation.y += velocity.0.y * TIME_STEP;
transform.translation.x += velocity.x * TIME_STEP;
transform.translation.y += velocity.y * TIME_STEP;
}
}

Expand Down Expand Up @@ -375,21 +375,21 @@ fn check_for_collisions(
// only reflect if the ball's velocity is going in the opposite direction of the
// collision
match collision {
Collision::Left => reflect_x = ball_velocity.0.x > 0.0,
Collision::Right => reflect_x = ball_velocity.0.x < 0.0,
Collision::Top => reflect_y = ball_velocity.0.y < 0.0,
Collision::Bottom => reflect_y = ball_velocity.0.y > 0.0,
Collision::Left => reflect_x = ball_velocity.x > 0.0,
Collision::Right => reflect_x = ball_velocity.x < 0.0,
Collision::Top => reflect_y = ball_velocity.y < 0.0,
Collision::Bottom => reflect_y = ball_velocity.y > 0.0,
Collision::Inside => { /* do nothing */ }
}

// reflect velocity on the x-axis if we hit something on the x-axis
if reflect_x {
ball_velocity.0.x = -ball_velocity.0.x;
ball_velocity.x = -ball_velocity.x;
}

// reflect velocity on the y-axis if we hit something on the y-axis
if reflect_y {
ball_velocity.0.y = -ball_velocity.0.y;
ball_velocity.y = -ball_velocity.y;
}
}
}
Expand Down
Loading

0 comments on commit ff71daa

Please sign in to comment.