feat: add an option to run a validation loop after each epoch
yjoer committed Oct 4, 2024
1 parent 79a9885 commit bf46acc
Showing 3 changed files with 121 additions and 2 deletions.
17 changes: 17 additions & 0 deletions random/keras_utils.py
@@ -0,0 +1,17 @@
# %%
import os

os.environ["KERAS_BACKEND"] = "torch"

# %%
import keras

# %%
pbar = keras.utils.Progbar(target=10)

for i in range(10):
    pbar.update(current=i + 1, values=[("loss", i + 1)])

pbar._values

# %%
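Note on the scratch file above: it pokes at the private _values attribute of keras.utils.Progbar. In recent Keras releases, update accumulates each named value as a [weighted_sum, count] pair, and the running average shown on the bar is their ratio; validation_loop below relies on the same layout to recover per-epoch means. A minimal sketch of that recovery, assuming this internal (and therefore unstable) layout:

import keras

pbar = keras.utils.Progbar(target=3)
for step, loss in enumerate([0.9, 0.6, 0.3], start=1):
    pbar.update(current=step, values=[("loss", loss)])

# Assumed internal layout: _values maps each name to
# [weighted_sum, total_weight]; the displayed mean is their ratio.
for name, (total, count) in pbar._values.items():
    print(name, total / max(1, count))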
42 changes: 40 additions & 2 deletions solutions/sports/numbers/resnet.py
@@ -24,6 +24,7 @@
from camp.utils.torch_utils import save_checkpoint
from solutions.sports.numbers.resnet_pipeline import collate_fn
from solutions.sports.numbers.resnet_pipeline import transforms
from solutions.sports.numbers.resnet_pipeline import validation_loop

# %matplotlib inline
# %config InlineBackend.figure_formats = ['retina']
@@ -33,6 +34,8 @@

# %%
OVERFITTING_TEST = False
VALIDATION_SPLIT = False
VALIDATION_SPLIT_TEST = False

DATASET_PATH = "s3://datasets/soccernet_legibility"
CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50"
@@ -43,6 +46,12 @@
if OVERFITTING_TEST:
    CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_test"

if VALIDATION_SPLIT:
    CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_val"

if VALIDATION_SPLIT_TEST:
    CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_val_test"

# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on CPU.") if device.type == "cpu" else print("Running on GPU.")
@@ -64,11 +73,24 @@
transforms=transforms,
)

if hasattr(os, "register_at_fork") and hasattr(fsspec, "asyn"):
    os.register_at_fork(after_in_child=fsspec.asyn.reset_lock)

if OVERFITTING_TEST:
    train_dataset = Subset(train_dataset, indices=[0])

if hasattr(os, "register_at_fork") and hasattr(fsspec, "asyn"):
    os.register_at_fork(after_in_child=fsspec.asyn.reset_lock)
if VALIDATION_SPLIT_TEST:
    train_dataset = Subset(train_dataset, indices=list(range(100)))

n_images = len(train_dataset)

if VALIDATION_SPLIT:
    split_point = int(0.8 * n_images)
    train_idx = list(range(split_point))
    val_idx = list(range(split_point, n_images))

    val_dataset = Subset(train_dataset, indices=val_idx)
    train_dataset = Subset(train_dataset, indices=train_idx)

# %%
if is_notebook():
@@ -99,6 +121,10 @@
n_epochs = 10
save_epochs = 10

if VALIDATION_SPLIT_TEST:
    n_epochs = 5
    save_epochs = 5

# %%
train_dataloader = DataLoader(
    train_dataset,
@@ -109,6 +135,15 @@
    persistent_workers=True,
)

if VALIDATION_SPLIT:
    val_dataloader = DataLoader(
        val_dataset,
        batch_size,
        num_workers=DATALOADER_WORKERS,
        collate_fn=collate_fn,
        persistent_workers=True,
    )

# %%
params = [p for p in resnet.parameters() if p.requires_grad]
optimizer = optim.AdamW(params, lr=0.001, weight_decay=0.0001)
@@ -156,6 +191,9 @@
storage_options=storage_options,
)

if VALIDATION_SPLIT:
    metrics_dict = validation_loop(resnet, val_dataloader, device)

epoch += 1

# %%
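Note on the split above: it is positional, putting the first 80% of samples in the training set and the last 20% in validation, which assumes the dataset's on-disk order carries no class bias. A hypothetical variant, not part of this commit, that shuffles indices with a fixed seed before splitting:

import torch
from torch.utils.data import Subset

# Hypothetical helper (not in this commit): shuffle indices with a fixed
# seed before the 80/20 split so an ordered dataset cannot bias the
# validation set; the two subsets remain disjoint.
def split_dataset(dataset, val_fraction=0.2, seed=42):
    n = len(dataset)
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=gen).tolist()
    split = int((1 - val_fraction) * n)
    return Subset(dataset, perm[:split]), Subset(dataset, perm[split:])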
64 changes: 64 additions & 0 deletions solutions/sports/numbers/resnet_pipeline.py
@@ -1,5 +1,15 @@
import keras
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2.functional as tvf
from torch.utils.data import DataLoader
from torchmetrics import AUROC
from torchmetrics import Accuracy
from torchmetrics import F1Score
from torchmetrics import MetricCollection
from torchmetrics import Precision
from torchmetrics import Recall


def transforms(image, target):
@@ -19,3 +29,57 @@ def transforms(image, target):

def collate_fn(batch):
    return tuple(zip(*batch))


def validation_loop(
    resnet: nn.Module,
    val_dataloader: DataLoader,
    device: torch.device,
):
    criterion = nn.BCEWithLogitsLoss()

    metrics = MetricCollection(
        [
            Accuracy(task="binary"),
            Precision(task="binary"),
            Recall(task="binary"),
            F1Score(task="binary"),
            AUROC(task="binary"),
        ]
    )

    resnet.eval()

    steps = 1
    pbar = keras.utils.Progbar(len(val_dataloader))

    with torch.no_grad():
        for images, targets in val_dataloader:
            images = torch.stack(images, dim=0).to(device)

            targets = torch.stack(targets, dim=0)
            targets_hot = F.one_hot(targets)
            targets_hot = targets_hot.to(device, dtype=torch.float32)
            targets = targets.to(device, dtype=torch.float32)

            outputs = resnet(images)
            predictions = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, targets_hot)

            metrics_dict = metrics.forward(predictions, targets)

            pbar.update(
                current=steps,
                values=[("loss", loss.item()), *list(metrics_dict.items())],
            )

            steps += 1

    resnet.train()

    means = {}
    for k, v in pbar._values.items():
        acc = v[0].item() if torch.is_tensor(v[0]) else v[0]
        means[k] = acc / max(1, v[1])

    return means
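
For reference, a hypothetical call site for the new function, mirroring the VALIDATION_SPLIT branch added in resnet.py. The key names are an assumption: torchmetrics' MetricCollection keys entries by metric class name (e.g. BinaryAUROC for AUROC(task="binary")), while "loss" comes from the tuple passed to the progress bar.

# Hypothetical usage, not part of this commit: run the validation loop
# after an epoch and read back the averaged metrics.
metrics_dict = validation_loop(resnet, val_dataloader, device)
print(f"val loss: {metrics_dict['loss']:.4f}")
print(f"val AUROC: {metrics_dict['BinaryAUROC']:.4f}")  # assumed key name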
