Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sampling normal distribution in expressions #362

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/firework.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn setup(mut commands: Commands, mut effects: ResMut<Assets<EffectAsset>>) {
let init_age = SetAttributeModifier::new(Attribute::AGE, age);

// Give a bit of variation by randomizing the lifetime per particle
let lifetime = writer.lit(0.8).uniform(writer.lit(1.2)).expr();
let lifetime = writer.lit(0.8).normal(writer.lit(1.2)).expr();
let init_lifetime = SetAttributeModifier::new(Attribute::LIFETIME, lifetime);

// Lifetime for trails
Expand Down
2 changes: 1 addition & 1 deletion examples/worms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn setup(
// scratch attribute.`
let set_initial_angle_modifier = SetAttributeModifier::new(
Attribute::F32_0,
writer.lit(0.0).uniform(writer.lit(PI * 2.0)).expr(),
writer.lit(0.0).normal(writer.lit(0.0)).expr(),
);

// Give each particle a random opaque color.
Expand Down
57 changes: 52 additions & 5 deletions src/graph/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ impl Module {
impl_module_binary!(step, Step);
impl_module_binary!(sub, Sub);
impl_module_binary!(uniform, UniformRand);
impl_module_binary!(normal, NormalRand);
impl_module_binary!(vec2, Vec2);

/// Build a ternary expression and append it to the module.
Expand Down Expand Up @@ -827,7 +828,7 @@ impl Expr {
Expr::Attribute(_) => false,
Expr::Unary { expr, .. } => module.has_side_effect(*expr),
Expr::Binary { left, right, op } => {
(*op == BinaryOperator::UniformRand)
(*op == BinaryOperator::UniformRand || *op == BinaryOperator::NormalRand)
|| module.has_side_effect(*left)
|| module.has_side_effect(*right)
}
Expand Down Expand Up @@ -1808,6 +1809,17 @@ pub enum BinaryOperator {
/// scalar type.
UniformRand,

/// Normal distribution random number operator.
///
/// Returns a value generated by a fast non-cryptographically-secure
/// pseudo-random number generator (PRNG) whose statistical characteristics
/// are undefined and generally focused around speed. The random value is
/// normally distributed with mean given by the first operand and standard
/// deviation by the second, which must be numeric types. If the operands
/// are vectors, they must be of the same rank, and the result is a vector
/// of that rank and same element scalar type.
NormalRand,

/// Constructor for 2-element vectors.
///
/// Given two scalar elements `x` and `y`, returns the vector consisting of
Expand Down Expand Up @@ -1840,18 +1852,23 @@ impl BinaryOperator {
| BinaryOperator::Min
| BinaryOperator::Step
| BinaryOperator::UniformRand
| BinaryOperator::NormalRand
| BinaryOperator::Vec2 => true,
}
}

/// Check if a binary operator needs a type suffix.
///
/// This is currently just for `rand_uniform`
/// (`BinaryOperator::UniformRand`), which is a function we define
/// ourselves. WGSL doesn't support user-defined function overloading, so
/// we need a suffix to disambiguate the types.
/// (`BinaryOperator::UniformRand`) and `rand_normal`
/// (`BinaryOperator::NormalRand`), which are functions we define ourselves.
/// WGSL doesn't support user-defined function overloading, so we need a
/// suffix to disambiguate the types.
pub fn needs_type_suffix(&self) -> bool {
*self == BinaryOperator::UniformRand
matches!(
*self,
BinaryOperator::UniformRand | BinaryOperator::NormalRand
)
}
}

Expand All @@ -1874,6 +1891,7 @@ impl ToWgslString for BinaryOperator {
BinaryOperator::Step => "step".to_string(),
BinaryOperator::Sub => "-".to_string(),
BinaryOperator::UniformRand => "rand_uniform".to_string(),
BinaryOperator::NormalRand => "rand_normal".to_string(),
BinaryOperator::Vec2 => "vec2".to_string(),
}
}
Expand Down Expand Up @@ -3283,6 +3301,35 @@ impl WriterExpr {
self.binary_op(other, BinaryOperator::UniformRand)
}

/// Apply the logical operator "normal" to this expression and another
/// expression.
///
/// This is a binary operator, which applies component-wise to vector
/// operand expressions. That is, for vectors, this produces a vector of
/// random values where each component is normally distributed with a mean
/// of the corresponding component of the first operand and a standard
/// deviation of the corresponding component of the second operand.
///
/// # Example
///
/// ```
/// # use bevy_hanabi::*;
/// # use bevy::math::Vec3;
/// # let mut w = ExprWriter::new();
/// // A literal expression `x = vec3<f32>(3., -2., 7.);`.
/// let x = w.lit(Vec3::new(3., -2., 7.));
///
/// // Another literal expression `y = vec3<f32>(1., 5., 7.);`.
/// let y = w.lit(Vec3::new(1., 5., 7.));
///
/// // A random variable normally distributed in [1:3]x[-2:5]x[7:7].
/// let z = x.normal(y);
/// ```
#[inline]
pub fn normal(self, other: Self) -> Self {
self.binary_op(other, BinaryOperator::NormalRand)
}

fn ternary_op(self, second: Self, third: Self, op: TernaryOperator) -> Self {
assert_eq!(self.module, second.module);
assert_eq!(self.module, third.module);
Expand Down
29 changes: 29 additions & 0 deletions src/render/vfx_common.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,35 @@ fn rand_uniform_vec4(a: vec4<f32>, b: vec4<f32>) -> vec4<f32> {
return a + frand4() * (b - a);
}

// Normal distribution computed using Box-Muller transform
fn rand_normal_f(mean: f32, std_dev: f32) -> f32 {
var u = frand();
var v = frand();
var r = sqrt(-2.0 * log(u));
return mean + std_dev * r * cos(tau * v);
}

fn rand_normal_vec2(mean: vec2f, std_dev: vec2f) -> vec2f {
var u = frand();
var v = frand2();
var r = sqrt(-2.0 * log(u));
return mean + std_dev * r * cos(tau * v);
}

fn rand_normal_vec3(mean: vec3f, std_dev: vec3f) -> vec3f {
var u = frand();
var v = frand3();
var r = sqrt(-2.0 * log(u));
return mean + std_dev * r * cos(tau * v);
}

fn rand_normal_vec4(mean: vec4f, std_dev: vec4f) -> vec4f {
var u = frand();
var v = frand4();
var r = sqrt(-2.0 * log(u));
return mean + std_dev * r * cos(tau * v);
}

fn proj(u: vec3<f32>, v: vec3<f32>) -> vec3<f32> {
return dot(v, u) / dot(u,u) * u;
}
3 changes: 2 additions & 1 deletion src/render/vfx_init.wgsl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#import bevy_hanabi::vfx_common::{
IndirectBuffer, ParticleGroup, RenderEffectMetadata, RenderGroupIndirect, SimParams, Spawner,
seed, tau, pcg_hash, to_float01, frand, frand2, frand3, frand4,
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4, proj
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4,
rand_normal_f, rand_normal_vec2, rand_normal_vec3, rand_normal_vec4, proj
}

struct Particle {
Expand Down
3 changes: 2 additions & 1 deletion src/render/vfx_render.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#import bevy_hanabi::vfx_common::{
DispatchIndirect, IndirectBuffer, SimParams, Spawner,
seed, tau, pcg_hash, to_float01, frand, frand2, frand3, frand4,
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4, proj
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4,
rand_normal_f, rand_normal_vec2, rand_normal_vec3, rand_normal_vec4, proj
}

struct Particle {
Expand Down
3 changes: 2 additions & 1 deletion src/render/vfx_update.wgsl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#import bevy_hanabi::vfx_common::{
IndirectBuffer, ParticleGroup, RenderEffectMetadata, RenderGroupIndirect, SimParams, Spawner,
seed, tau, pcg_hash, to_float01, frand, frand2, frand3, frand4,
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4, proj
rand_uniform_f, rand_uniform_vec2, rand_uniform_vec3, rand_uniform_vec4,
rand_normal_f, rand_normal_vec2, rand_normal_vec3, rand_normal_vec4, proj
}

struct Particle {
Expand Down
Loading