diff --git a/differentiable_physics/mass_spring.py b/differentiable_physics/mass_spring.py new file mode 100644 index 0000000000..e7d83d487f --- /dev/null +++ b/differentiable_physics/mass_spring.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import argparse +import matplotlib.pyplot as plt +import os + + +class MassSpringSystem(nn.Module): + def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"): + super().__init__() + self.device = device + self.mass = mass + self.springs = springs + self.dt = dt + self.gravity = gravity + + # Particle 0 is fixed at the origin + self.initial_position_0 = torch.tensor([0.0, 0.0], device=device) + + # Remaining particles are trainable + self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device)) + + # Velocities + self.velocities = torch.zeros(num_particles, 2, device=device) + + def forward(self, steps): + positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0) + velocities = self.velocities + + for _ in range(steps): + forces = torch.zeros_like(positions) + + # Compute spring forces + for (i, j, rest_length, stiffness) in self.springs: + xi, xj = positions[i], positions[j] + dir_vec = xj - xi + dist = dir_vec.norm() + force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6) + forces[i] += force + forces[j] -= force + + # Apply gravity + forces[:, 1] -= self.gravity * self.mass + + # Semi-implicit Euler integration + acceleration = forces / self.mass + velocities = velocities + acceleration * self.dt + positions = positions + velocities * self.dt + + # Fix particle 0 at origin + positions[0] = self.initial_position_0 + velocities[0] = torch.tensor([0.0, 0.0], device=positions.device) + + return positions + + +def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"): + plt.figure(figsize=(6, 4)) + plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x') + plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o') + plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*') + plt.title("Mass-Spring System Positions") + plt.xlabel("X") + plt.ylabel("Y") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig(save_path) + print(f"Saved visualization to {os.path.abspath(save_path)}") + plt.close() + + +def train(args): + #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") + print(f"Using device: {device}") + system = MassSpringSystem( + num_particles=args.num_particles, + springs=[(0, 1, 1.0, args.stiffness)], + mass=args.mass, + dt=args.dt, + gravity=args.gravity, + device=device, + ) + + optimizer = optim.Adam(system.parameters(), lr=args.lr) + target_positions = torch.tensor( + [[0.0, 0.0], [1.0, 0.0]], device=device + ) + + for epoch in range(args.epochs): + optimizer.zero_grad() + final_positions = system(args.steps) + loss = (final_positions - target_positions).pow(2).mean() + loss.backward() + optimizer.step() + + if (epoch + 1) % args.log_interval == 0: + print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}") + + # Visualization + initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy() + visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy()) + + print("\nTraining completed.") + print(f"Final positions:\n{final_positions.detach().cpu().numpy()}") + print(f"Target positions:\n{target_positions.cpu().numpy()}") + + +def evaluate(args): + #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") + print(f"Using device: {device}") + system = MassSpringSystem( + num_particles=args.num_particles, + springs=[(0, 1, 1.0, args.stiffness)], + mass=args.mass, + dt=args.dt, + gravity=args.gravity, + device=device, + ) + + with torch.no_grad(): + final_positions = system(args.steps) + print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System") + parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs") + parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass") + parser.add_argument("--lr", type=float, default=0.01, help="Learning rate") + parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration") + parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle") + parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant") + parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system") + parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval") + parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs") + parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength") + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.mode == "train": + train(args) + elif args.mode == "eval": + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/differentiable_physics/mass_spring_viz.png b/differentiable_physics/mass_spring_viz.png new file mode 100644 index 0000000000..0b3a08509c Binary files /dev/null and b/differentiable_physics/mass_spring_viz.png differ diff --git a/differentiable_physics/readme.md b/differentiable_physics/readme.md new file mode 100644 index 0000000000..6c0f60512b --- /dev/null +++ b/differentiable_physics/readme.md @@ -0,0 +1,43 @@ +# Differentiable Physics: Mass-Spring System + +This example demonstrates a simple differentiable mass-spring system using PyTorch. + +Particles are connected by springs and evolve under the forces exerted by the springs and gravity. +The system is fully differentiable, allowing the optimization of particle positions to match a target configuration using gradient-based learning. + +--- + +## Files + +- `mass_spring.py` — Implements the mass-spring simulation, training loop, and evaluation. +- `README.md` — Usage instructions and description. + + +--- + +## Requirements + +- Python 3.8+ +- PyTorch +- pip install -r requirements.txt + +No external dependencies are required apart from PyTorch. + +--- + +## Usage + +First, ensure PyTorch is installed. + +#### Train the system + +```bash +python mass_spring.py --mode train + + +##### Visualization + +After training, the system's final positions are compared to the target positions. The plot below illustrates this comparison: + +![Mass-Spring System Visualization](mass_spring_viz.png) + diff --git a/differentiable_physics/requirements.txt b/differentiable_physics/requirements.txt new file mode 100644 index 0000000000..f88df426f8 --- /dev/null +++ b/differentiable_physics/requirements.txt @@ -0,0 +1,3 @@ +torch>=2.6 +matplotlib + diff --git a/run_python_examples.sh b/run_python_examples.sh index e075a28ed2..d5c34eb67d 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -164,6 +164,11 @@ function gat() { uv run main.py --epochs 1 --dry-run || error "graph attention network failed" } +function differentiable_physics() { + uv run mass_spring.py --mode train --epochs 5 --steps 3 || error "differentiable_physics example failed" +} + + eval "base_$(declare -f stop)" function stop() { @@ -217,6 +222,7 @@ function run_all() { run fx run gcn run gat + run differentiable_physics } # by default, run all examples