Skip to content

Commit

Permalink
Define device early to make run on Mac possible. Also, include timest…
Browse files Browse the repository at this point in the history
…amp in run name
  • Loading branch information
ngastzepeda committed Jan 25, 2024
1 parent 3ce2a35 commit c132159
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions notebooks/cvrptw/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from rl4co.models.zoo.am import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

device_str = (
"cuda"
if torch.cuda.is_available()
else "mps"
if (torch.backends.mps.is_available() and torch.backends.mps.is_built())
else "cpu"
)
device = torch.device(device_str)

env_cvrptw = CVRPTWEnv(
num_loc=30,
min_loc=0,
Expand All @@ -17,6 +26,7 @@
min_time=0,
max_time=480,
scale=True,
device=device_str,
)

env = env_cvrptw
Expand All @@ -34,7 +44,7 @@
### --- random policy --- ###
reward, td, actions = rollout(
env=env,
td=env.reset(batch_size=[batch_size]),
td=env.reset(batch_size=[batch_size]).to(device),
policy=random_policy,
max_steps=1000,
)
Expand All @@ -43,7 +53,7 @@
env.get_reward(td, actions)
CVRPTWEnv.check_solution_validity(td, actions)

env.render(td, actions)
# env.render(td, actions)

### --- AM --- ###
# Model: default is AM with REINFORCE and greedy rollout baseline
Expand All @@ -55,17 +65,22 @@
)

# Greedy rollouts over untrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
model = model.to(device)
out = model(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)

### --- Logging --- ###
from datetime import date, datetime
import wandb
from lightning.pytorch.loggers import WandbLogger

date_time_str = datetime.now().strftime("%Y/%m/%d_%H:%M:%S")

wandb.login()
logger = WandbLogger(project="routefinder", name="cvrptw-am")
logger = WandbLogger(
project="routefinder",
name=f"cvrptw-am_{date_time_str}",
)

### --- Training --- ###
# The RL4CO trainer is a wrapper around PyTorch Lightning's `Trainer` class which adds some functionality and more efficient defaults
Expand All @@ -81,3 +96,6 @@

### --- Testing --- ###
trainer.test(model)

### --- Saving --- ###
trainer.save_checkpoint(f"data/cvrptw_{date_time_str}.ckpt")

0 comments on commit c132159

Please sign in to comment.