Skip to content

Commit

Permalink
neatened
Browse files Browse the repository at this point in the history
  • Loading branch information
aidanscannell committed Feb 9, 2024
1 parent 156fab3 commit db7f6c7
Showing 1 changed file with 39 additions and 53 deletions.
92 changes: 39 additions & 53 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@
logger = logging.getLogger(__name__)


class MetricLogger:
    """Report evaluation metrics for named models.

    Every call to :meth:`log` writes a one-line summary through the module
    ``logger``. When a W&B run is active, the metrics are additionally
    appended as a row of ``self.df`` and the whole table is logged to W&B
    under the key ``"Metrics"``.
    """

    def __init__(self):
        # Table of logged metrics; rows are appended by ``log``.
        self.df = pd.DataFrame(columns=["Model", "loss", "acc", "nlpd", "ece"])

    def log(self, metrics: dict, name: str):
        """Report ``metrics`` for the model called ``name``.

        Args:
            metrics: Mapping that must contain the keys ``"nlpd"``, ``"acc"``
                and ``"ece"``. It is not modified by this call.
            name: Label used as the prefix of the log line and stored under
                the ``"Model"`` key of the table row.

        Raises:
            KeyError: If ``"nlpd"``, ``"acc"`` or ``"ece"`` is missing.
        """
        logger.info(
            f"{name} NLPD {metrics['nlpd']} | ACC: {metrics['acc']} | ECE: {metrics['ece']}"
        )
        # Build the row from a copy so the caller's dict is left untouched
        # (previously "Model" was injected into the caller's dict in place).
        row = {**metrics, "Model": name}
        if wandb.run is not None:
            self.df.loc[len(self.df.index)] = row
            wandb.log({"Metrics": wandb.Table(data=self.df)})


@dataclass
class TrainConfig:
# Dataset
Expand Down Expand Up @@ -69,26 +83,25 @@ class TrainConfig:

@hydra.main(version_base="1.3", config_path="./cfgs", config_name="train")
def train(cfg: TrainConfig):
# Make experiment reproducible
##### Make experiment reproducible #####
torch.cuda.manual_seed(cfg.seed)
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
# random.seed(cfg.seed)
eval('setattr(torch.backends.cudnn, "determinstic", True)')
eval('setattr(torch.backends.cudnn, "benchmark", False)')

# Use GPU if requested and available
##### Use GPU if requested and available #####
if "cuda" in cfg.device:
if torch.cuda.is_available():
cfg.device = "cuda"
else:
logger.info("CUDA requested but not available")
cfg.device = "cpu"
logger.info("Using device: {}".format(cfg.device))
logger.info(f"Using device: {cfg.device}")

# Initialize W&B
##### Initialize W&B #####
if cfg.use_wandb:
run = wandb.init(
wandb.init(
project=cfg.wandb_project_name,
name=cfg.wandb_run_name,
group=cfg.dataset,
Expand All @@ -98,10 +111,9 @@ def train(cfg: TrainConfig):
),
dir=get_original_cwd(), # don't nest wandb inside hydra dir
)
save_dir = run.dir if cfg.use_wandb else "./"
ckpt_path = os.path.join(save_dir, "best_ckpt.pt")
ckpt_path = os.path.join(get_original_cwd(), "best_ckpt.pt")

# Load the data with train/val/test split
##### Load data with train/val/test split #####
save_dir = f"{get_original_cwd()}/data"
if "FMNIST" in cfg.dataset:
dataset_fn = torchvision.datasets.FashionMNIST
Expand Down Expand Up @@ -151,9 +163,9 @@ def train(cfg: TrainConfig):
ds_test, batch_size=cfg.test_batch_size, shuffle=True, pin_memory=True
)

# Instantiate SFR
##### Instantiate SFR #####
# TODO This doesn't use tanh...
network = utils.CIFAR10Net(in_channels=in_channels, n_out=output_dim, use_tanh=True)
network = h.CIFAR10Net(in_channels=in_channels, n_out=output_dim, use_tanh=True)
prior = priors.Gaussian(
params=network.parameters, prior_precision=cfg.prior_precision
)
Expand All @@ -170,7 +182,7 @@ def train(cfg: TrainConfig):
)
optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=cfg.lr)

early_stopper = utils.EarlyStopper(
early_stopper = h.EarlyStopper(
patience=int(cfg.early_stop_patience / cfg.logging_epoch_freq),
min_delta=cfg.early_stop_min_delta,
)
Expand Down Expand Up @@ -201,7 +213,7 @@ def evaluate(model: sfr.SFR, data_loader: DataLoader, sfr_pred: bool = False):
model.train()
return metrics

# Train NN weights with empirical regularized risk
##### Train NN weights with empirical regularized risk #####
best_loss = float("inf")
for epoch_idx in tqdm(list(range(cfg.n_epochs)), total=cfg.n_epochs):
with tqdm(train_loader, unit="batch") as tepoch:
Expand Down Expand Up @@ -230,51 +242,29 @@ def evaluate(model: sfr.SFR, data_loader: DataLoader, sfr_pred: bool = False):
logger.info("Early stopping criteria met, stopping training...")
break

# Load checkpoint
ckpt = torch.load(ckpt_path)
print(f"ckpt {ckpt}")
print(f"sfr {[p for p in model.parameters()]}")
model.load_state_dict(ckpt["model"])

logger.info("Finished training")

class MetricLogger:
    """Report evaluation metrics for named models.

    Every call to :meth:`log` writes a one-line summary through the module
    ``logger``. When a W&B run is active, the metrics are additionally
    appended as a row of ``self.df`` and the whole table is logged to W&B
    under the key ``"Metrics"``.
    """

    def __init__(self):
        # Table of logged metrics; rows are appended by ``log``.
        self.df = pd.DataFrame(columns=["Model", "loss", "acc", "nlpd", "ece"])

    def log(self, metrics: dict, name: str):
        """Report ``metrics`` for the model called ``name``.

        Args:
            metrics: Mapping that must contain the keys ``"nlpd"``, ``"acc"``
                and ``"ece"``. It is not modified by this call.
            name: Label used as the prefix of the log line and stored under
                the ``"Model"`` key of the table row.

        Raises:
            KeyError: If ``"nlpd"``, ``"acc"`` or ``"ece"`` is missing.
        """
        logger.info(
            f"{name} NLPD {metrics['nlpd']} | ACC: {metrics['acc']} | ECE: {metrics['ece']}"
        )
        # Build the row from a copy so the caller's dict is left untouched
        # (previously "Model" was injected into the caller's dict in place).
        row = {**metrics, "Model": name}
        if wandb.run is not None:
            self.df.loc[len(self.df.index)] = row
            wandb.log({"Metrics": wandb.Table(data=self.df)})

# Calculate NN's metrics and log
nn_metrics = evaluate(model, data_loader=test_loader, sfr_pred=False)
metric_logger = MetricLogger()
metric_logger.log(nn_metrics, name="NN")

if cfg.debug:
nn_metrics = evaluate(model, data_loader=train_loader, sfr_pred=False)
metric_logger.log(nn_metrics, name="NN-train")
##### Load the best checkpoint (with lowest validation loss) #####
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt["model"])

# Calculate posterior (dual parameters etc)
##### Fit SFR posterior (dual parameters etc) #####
logger.info("Fitting SFR...")
model.fit(train_loader=train_loader)
logger.info("Finished fitting SFR")

if cfg.debug:
sfr_metrics = evaluate(model, data_loader=train_loader, sfr_pred=True)
metric_logger.log(sfr_metrics, name="SFR-train")
##### Log metrics on test set #####
metric_logger = MetricLogger()

# Calculate SFR's metrics and log
# Calculate metrics for NN MAP
nn_metrics = evaluate(model, data_loader=test_loader, sfr_pred=False)
metric_logger.log(nn_metrics, name="NN")

# Calculate metrics for SFR
sfr_metrics = evaluate(model, data_loader=test_loader, sfr_pred=True)
metric_logger.log(sfr_metrics, name="SFR")
nn_metrics = evaluate(model, data_loader=test_loader, sfr_pred=False)
metric_logger.log(nn_metrics, name="NN double")

##### (Optionally) Optimize the prior precision posthoc using BO #####
if cfg.optimize_prior_prec:
model.optimize_prior_precision(
pred_type="gp",
Expand All @@ -286,13 +276,9 @@ def log(self, metrics: dict, name: str):
num_trials=20,
)

if cfg.debug:
sfr_metrics = evaluate(model, data_loader=train_loader, sfr_pred=True)
metric_logger.log(sfr_metrics, name="SFR-train-bo")

# Calculate SFR's metrics and log
##### Calculate metrics for SFR with prior precision tuned #####
sfr_metrics = evaluate(model, data_loader=test_loader, sfr_pred=True)
metric_logger.log(sfr_metrics, name="SFR-bo")
metric_logger.log(sfr_metrics, name="SFR (δ-tuned)")


if __name__ == "__main__":
Expand Down

0 comments on commit db7f6c7

Please sign in to comment.