-
Notifications
You must be signed in to change notification settings - Fork 9.7k
Add Differentiable Physics: Mass-Spring System example #1332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7d3d2a1
39a0c8e
3d01b48
77c8abc
84072d5
8a7cf5e
f20fa20
e14251f
96e04f1
35b0afa
f1a806e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,154 @@ | ||||
import torch | ||||
import torch.nn as nn | ||||
import torch.optim as optim | ||||
import argparse | ||||
import matplotlib.pyplot as plt | ||||
import os | ||||
|
||||
|
||||
class MassSpringSystem(nn.Module): | ||||
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"): | ||||
super().__init__() | ||||
self.device = device | ||||
self.mass = mass | ||||
self.springs = springs | ||||
self.dt = dt | ||||
self.gravity = gravity | ||||
|
||||
# Particle 0 is fixed at the origin | ||||
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device) | ||||
|
||||
# Remaining particles are trainable | ||||
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device)) | ||||
|
||||
# Velocities | ||||
self.velocities = torch.zeros(num_particles, 2, device=device) | ||||
|
||||
def forward(self, steps): | ||||
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0) | ||||
velocities = self.velocities | ||||
|
||||
for _ in range(steps): | ||||
forces = torch.zeros_like(positions) | ||||
|
||||
# Compute spring forces | ||||
for (i, j, rest_length, stiffness) in self.springs: | ||||
xi, xj = positions[i], positions[j] | ||||
dir_vec = xj - xi | ||||
dist = dir_vec.norm() | ||||
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6) | ||||
forces[i] += force | ||||
forces[j] -= force | ||||
|
||||
# Apply gravity | ||||
forces[:, 1] -= self.gravity * self.mass | ||||
|
||||
# Semi-implicit Euler integration | ||||
acceleration = forces / self.mass | ||||
velocities = velocities + acceleration * self.dt | ||||
positions = positions + velocities * self.dt | ||||
|
||||
# Fix particle 0 at origin | ||||
positions[0] = self.initial_position_0 | ||||
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device) | ||||
|
||||
return positions | ||||
|
||||
|
||||
def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"): | ||||
plt.figure(figsize=(6, 4)) | ||||
plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x') | ||||
plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o') | ||||
plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*') | ||||
plt.title("Mass-Spring System Positions") | ||||
plt.xlabel("X") | ||||
plt.ylabel("Y") | ||||
plt.legend() | ||||
plt.grid(True) | ||||
plt.tight_layout() | ||||
plt.savefig(save_path) | ||||
print(f"Saved visualization to {os.path.abspath(save_path)}") | ||||
plt.close() | ||||
|
||||
|
||||
def train(args): | ||||
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") | ||||
print(f"Using device: {device}") | ||||
system = MassSpringSystem( | ||||
num_particles=args.num_particles, | ||||
springs=[(0, 1, 1.0, args.stiffness)], | ||||
mass=args.mass, | ||||
dt=args.dt, | ||||
gravity=args.gravity, | ||||
device=device, | ||||
) | ||||
|
||||
optimizer = optim.Adam(system.parameters(), lr=args.lr) | ||||
target_positions = torch.tensor( | ||||
[[0.0, 0.0], [1.0, 0.0]], device=device | ||||
) | ||||
|
||||
for epoch in range(args.epochs): | ||||
optimizer.zero_grad() | ||||
final_positions = system(args.steps) | ||||
loss = (final_positions - target_positions).pow(2).mean() | ||||
loss.backward() | ||||
optimizer.step() | ||||
|
||||
if (epoch + 1) % args.log_interval == 0: | ||||
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}") | ||||
|
||||
# Visualization | ||||
initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy() | ||||
visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy()) | ||||
|
||||
print("\nTraining completed.") | ||||
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}") | ||||
print(f"Target positions:\n{target_positions.cpu().numpy()}") | ||||
|
||||
|
||||
def evaluate(args): | ||||
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove:
Suggested change
|
||||
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") | ||||
print(f"Using device: {device}") | ||||
system = MassSpringSystem( | ||||
num_particles=args.num_particles, | ||||
springs=[(0, 1, 1.0, args.stiffness)], | ||||
mass=args.mass, | ||||
dt=args.dt, | ||||
gravity=args.gravity, | ||||
device=device, | ||||
) | ||||
|
||||
with torch.no_grad(): | ||||
final_positions = system(args.steps) | ||||
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}") | ||||
|
||||
|
||||
def parse_args(): | ||||
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System") | ||||
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs") | ||||
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass") | ||||
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate") | ||||
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration") | ||||
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle") | ||||
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant") | ||||
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system") | ||||
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval") | ||||
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs") | ||||
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength") | ||||
return parser.parse_args() | ||||
|
||||
|
||||
def main(): | ||||
args = parse_args() | ||||
|
||||
if args.mode == "train": | ||||
train(args) | ||||
elif args.mode == "eval": | ||||
evaluate(args) | ||||
|
||||
|
||||
if __name__ == "__main__": | ||||
main() |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,43 @@ | ||||||
# Differentiable Physics: Mass-Spring System | ||||||
|
||||||
This example demonstrates a simple differentiable mass-spring system using PyTorch. | ||||||
|
||||||
Particles are connected by springs and evolve under the forces exerted by the springs and gravity. | ||||||
The system is fully differentiable, allowing the optimization of particle positions to match a target configuration using gradient-based learning. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I afraid I don't understand the task you are trying to solve here. Can it please, be thoroughly described and link to the associated paper provided? Looking into provided image:
|
||||||
|
||||||
--- | ||||||
|
||||||
## Files | ||||||
|
||||||
- `mass_spring.py` — Implements the mass-spring simulation, training loop, and evaluation. | ||||||
- `README.md` — Usage instructions and description. | ||||||
|
||||||
|
||||||
--- | ||||||
|
||||||
## Requirements | ||||||
|
||||||
- Python 3.8+ | ||||||
- PyTorch | ||||||
- pip install -r requirements.txt | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. matplotlib is missing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add
Suggested change
to be consistent with other examples |
||||||
No external dependencies are required apart from PyTorch. | ||||||
|
||||||
--- | ||||||
|
||||||
## Usage | ||||||
|
||||||
First, ensure PyTorch is installed. | ||||||
|
||||||
#### Train the system | ||||||
|
||||||
```bash | ||||||
python mass_spring.py --mode train | ||||||
Comment on lines
+34
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code block is not properly closed. |
||||||
|
||||||
|
||||||
##### Visualization | ||||||
|
||||||
After training, the system's final positions are compared to the target positions. The plot below illustrates this comparison: | ||||||
|
||||||
 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not rendered. Maybe because above code block was not closed. |
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
torch>=2.6 | ||
matplotlib | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove commented code as it's replaced by
torch.accelerator
: