From c132159a5fc65fd253f3daa1d0db9c3239941a1d Mon Sep 17 00:00:00 2001
From: ngastzepeda
Date: Thu, 25 Jan 2024 21:17:55 +0100
Subject: [PATCH] Define device early to make run on Mac possible. Also, include timestamp in run name

---
Review note: the run timestamp uses only "-" and "_" separators because the
same string is embedded in the checkpoint file name; "/" would be treated as
a directory separator and ":" is not allowed in file names on Windows.

 notebooks/cvrptw/run.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/notebooks/cvrptw/run.py b/notebooks/cvrptw/run.py
index 67b32502..6967ceb7 100644
--- a/notebooks/cvrptw/run.py
+++ b/notebooks/cvrptw/run.py
@@ -6,6 +6,15 @@
 from rl4co.models.zoo.am import AttentionModel
 from rl4co.utils.trainer import RL4COTrainer
 
+device_str = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "mps"
+    if (torch.backends.mps.is_available() and torch.backends.mps.is_built())
+    else "cpu"
+)
+device = torch.device(device_str)
+
 env_cvrptw = CVRPTWEnv(
     num_loc=30,
     min_loc=0,
@@ -17,6 +26,7 @@
     min_time=0,
     max_time=480,
     scale=True,
+    device=device_str,
 )
 
 env = env_cvrptw
@@ -34,7 +44,7 @@
     ### --- random policy --- ###
     reward, td, actions = rollout(
         env=env,
-        td=env.reset(batch_size=[batch_size]),
+        td=env.reset(batch_size=[batch_size]).to(device),
         policy=random_policy,
         max_steps=1000,
     )
@@ -43,7 +53,7 @@
 
     env.get_reward(td, actions)
     CVRPTWEnv.check_solution_validity(td, actions)
-    env.render(td, actions)
+    # env.render(td, actions)
 
     ### --- AM --- ###
     # Model: default is AM with REINFORCE and greedy rollout baseline
@@ -55,17 +65,22 @@
     )
 
     # Greedy rollouts over untrained model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     td_init = env.reset(batch_size=[3]).to(device)
     model = model.to(device)
     out = model(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)
 
     ### --- Logging --- ###
+    from datetime import datetime
     import wandb
     from lightning.pytorch.loggers import WandbLogger
 
+    date_time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # path-safe: no "/" or ":"
+
     wandb.login()
-    logger = WandbLogger(project="routefinder", name="cvrptw-am")
+    logger = WandbLogger(
+        project="routefinder",
+        name=f"cvrptw-am_{date_time_str}",
+    )
 
     ### --- Training --- ###
     # The RL4CO trainer is a wrapper around PyTorch Lightning's `Trainer` class which adds some functionality and more efficient defaults
@@ -81,3 +96,6 @@
 
     ### --- Testing --- ###
     trainer.test(model)
+
+    ### --- Saving --- ###
+    trainer.save_checkpoint(f"data/cvrptw_{date_time_str}.ckpt")