Skip to content

Commit

Permalink
Merge branch 'nbody_example' into 'main'
Browse files Browse the repository at this point in the history
Adding tiled nbody example

See merge request omniverse/warp!974
  • Loading branch information
daedalus5 committed Jan 17, 2025
2 parents 1336132 + 146791b commit b7c9d2a
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Add per-module option to add CUDA-C line information for profiling, use `wp.set_module_options({"lineinfo": True})`.
- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication.
- Add operator overloads for `wp.struct` objects by defining `wp.func` functions ([GH-392](https://github.com/NVIDIA/warp/issues/392)).
- Add `example_tile_nbody.py`, an N-Body gravitational simulation example using Warp tile primitives.

### Changed

Expand Down
180 changes: 180 additions & 0 deletions warp/examples/tile/example_tile_nbody.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

###########################################################################
# Example N-Body
#
# Shows how to simulate an N-Body gravitational problem using an all-pairs
# approach with Warp tile primitives.
#
# References:
# L. Nyland, M. Harris, and J. Prins. "Fast N-Body Simulation with
# CUDA" in GPU Gems 3. H. Nguyen, Addison-Wesley Professional, 2007.
# https://developer.nvidia.com/gpugems/gpugems3/part-v-physics-simulation/chapter-31-fast-n-body-simulation-cuda
#
###########################################################################

import argparse

import numpy as np

import warp as wp

wp.init()

DT = wp.constant(0.01)
SOFTENING_SQ = wp.constant(0.1**2) # Softening factor for numerical stability
TILE_SIZE = wp.constant(64)
PARTICLE_MASS = wp.constant(1.0)


@wp.func
def body_body_interaction(p0: wp.vec3, pi: wp.vec3):
"""Return the acceleration of the particle at position `p0` due to the
particle at position `pi`."""
r = pi - p0

dist_sq = wp.length_sq(r) + SOFTENING_SQ

inv_dist = 1.0 / wp.sqrt(dist_sq)
inv_dist_cubed = inv_dist * inv_dist * inv_dist

acc = PARTICLE_MASS * inv_dist_cubed * r

return acc


@wp.kernel
def integrate_bodies_tiled(
old_position: wp.array(dtype=wp.vec3),
velocity: wp.array(dtype=wp.vec3),
new_position: wp.array(dtype=wp.vec3),
num_bodies: int,
):
i = wp.tid()

p0 = old_position[i]

accel = wp.vec3(0.0, 0.0, 0.0)

for k in range(num_bodies / TILE_SIZE):
k_tile = wp.tile_load(old_position, k, TILE_SIZE, storage="shared")
for idx in range(TILE_SIZE):
pi = k_tile[0, idx]
accel += body_body_interaction(p0, pi)

# Advance the velocity one timestep (in-place)
velocity[i] = velocity[i] + accel * DT

# Advance the positions (using a second array)
new_position[i] = old_position[i] + DT * velocity[i]


class Example:
def __init__(self, headless=False, num_bodies=1024):
self.num_bodies = num_bodies

rng = np.random.default_rng(42)

# Sample the surface of a sphere
r = 10.0 * (num_bodies / 1024) ** (1 / 2) # Scale factor to maintain a constant density
phi = np.arccos(1.0 - 2.0 * rng.uniform(size=self.num_bodies))
theta = rng.uniform(low=0.0, high=2.0 * np.pi, size=self.num_bodies)
x = r * np.cos(theta) * np.sin(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(phi)

self.scale = r
init_pos_np = np.stack((x, y, z), axis=1)

self.pos_array_0 = wp.array(init_pos_np, dtype=wp.vec3)
self.pos_array_1 = wp.empty_like(self.pos_array_0)
self.vel_array = wp.zeros(self.num_bodies, dtype=wp.vec3)

if headless:
self.scatter_plot = None
else:
self.scatter_plot = self.create_plot()

def create_plot(self):
import matplotlib.pyplot as plt

# Create a figure and a 3D axis for the plot
self.fig = plt.figure()
ax = self.fig.add_subplot(111, projection="3d")

# Scatter plot of initial positions
init_pos_np = self.pos_array_0.numpy()
scatter_plot = ax.scatter(init_pos_np[:, 0], init_pos_np[:, 1], init_pos_np[:, 2], c="#76b900", alpha=0.5)

# Set axis limits
ax.set_xlim(-self.scale, self.scale)
ax.set_ylim(-self.scale, self.scale)
ax.set_zlim(-self.scale, self.scale)

return scatter_plot

def step(self):
wp.launch(
integrate_bodies_tiled,
dim=self.num_bodies,
inputs=[self.pos_array_0, self.vel_array, self.pos_array_1, self.num_bodies],
block_dim=TILE_SIZE,
)

# Swap arrays
(self.pos_array_0, self.pos_array_1) = (self.pos_array_1, self.pos_array_0)

def render(self):
positions_cpu = self.pos_array_0.numpy()

# Update scatter plot positions
self.scatter_plot._offsets3d = (
positions_cpu[:, 0],
positions_cpu[:, 1],
positions_cpu[:, 2],
)

# Function to update the scatter plot
def step_and_render(self, frame):
self.step()
self.render()


if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
parser.add_argument("--num_frames", type=int, default=1000, help="Total number of frames.")
parser.add_argument("-N", help="Number of bodies. Should be a multiple of 64.", type=int, default=1024)
parser.add_argument(
"--headless",
action="store_true",
help="Run in headless mode, suppressing the opening of any graphical windows.",
)

args = parser.parse_known_args()[0]

if args.device == "cpu":
print("This example only runs on CUDA devices.")
exit()

with wp.ScopedDevice(args.device):
example = Example(headless=args.headless, num_bodies=args.N)

if not args.headless:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Create the animation
ani = FuncAnimation(example.fig, example.step_and_render, frames=args.num_frames, interval=50, repeat=False)

# Display the animation
plt.show()

else:
for _ in range(args.num_frames):
example.step()

0 comments on commit b7c9d2a

Please sign in to comment.