diff --git a/random/keras_utils.py b/random/keras_utils.py new file mode 100644 index 0000000..9ae6a62 --- /dev/null +++ b/random/keras_utils.py @@ -0,0 +1,17 @@ +# %% +import os + +os.environ["KERAS_BACKEND"] = "torch" + +# %% +import keras + +# %% +pbar = keras.utils.Progbar(target=10) + +for i in range(10): + pbar.update(current=i + 1, values=[("loss", i + 1)]) + +pbar._values + +# %% diff --git a/solutions/sports/numbers/resnet.py b/solutions/sports/numbers/resnet.py index 9204528..4bccb05 100644 --- a/solutions/sports/numbers/resnet.py +++ b/solutions/sports/numbers/resnet.py @@ -24,6 +24,7 @@ from camp.utils.torch_utils import save_checkpoint from solutions.sports.numbers.resnet_pipeline import collate_fn from solutions.sports.numbers.resnet_pipeline import transforms +from solutions.sports.numbers.resnet_pipeline import validation_loop # %matplotlib inline # %config InlineBackend.figure_formats = ['retina'] @@ -33,6 +34,8 @@ # %% OVERFITTING_TEST = False +VALIDATION_SPLIT = False +VALIDATION_SPLIT_TEST = False DATASET_PATH = "s3://datasets/soccernet_legibility" CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50" @@ -43,6 +46,12 @@ if OVERFITTING_TEST: CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_test" +if VALIDATION_SPLIT: + CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_val" + +if VALIDATION_SPLIT_TEST: + CHECKPOINT_PATH = "s3://models/soccernet_legibility/resnet_50_val_test" + # %% device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Running on CPU.") if device.type == "cpu" else print("Running on GPU.") @@ -64,11 +73,24 @@ transforms=transforms, ) +if hasattr(os, "register_at_fork") and hasattr(fsspec, "asyn"): + os.register_at_fork(after_in_child=fsspec.asyn.reset_lock) + if OVERFITTING_TEST: train_dataset = Subset(train_dataset, indices=[0]) -if hasattr(os, "register_at_fork") and hasattr(fsspec, "asyn"): - os.register_at_fork(after_in_child=fsspec.asyn.reset_lock) +if 
VALIDATION_SPLIT_TEST: + train_dataset = Subset(train_dataset, indices=list(range(100))) + +n_images = len(train_dataset) + +if VALIDATION_SPLIT: + split_point = int(0.8 * n_images) + train_idx = list(range(split_point)) + val_idx = list(range(split_point, n_images)) + + val_dataset = Subset(train_dataset, indices=val_idx) + train_dataset = Subset(train_dataset, indices=train_idx) # %% if is_notebook(): @@ -99,6 +121,10 @@ n_epochs = 10 save_epochs = 10 +if VALIDATION_SPLIT_TEST: + n_epochs = 5 + save_epochs = 5 + # %% train_dataloader = DataLoader( train_dataset, @@ -109,6 +135,15 @@ persistent_workers=True, ) +if VALIDATION_SPLIT: + val_dataloader = DataLoader( + val_dataset, + batch_size, + num_workers=DATALOADER_WORKERS, + collate_fn=collate_fn, + persistent_workers=True, + ) + # %% params = [p for p in resnet.parameters() if p.requires_grad] optimizer = optim.AdamW(params, lr=0.001, weight_decay=0.0001) @@ -156,6 +191,9 @@ storage_options=storage_options, ) + if VALIDATION_SPLIT: + metrics_dict = validation_loop(resnet, val_dataloader, device) + epoch += 1 # %% diff --git a/solutions/sports/numbers/resnet_pipeline.py b/solutions/sports/numbers/resnet_pipeline.py index 1a4ebb9..516f830 100644 --- a/solutions/sports/numbers/resnet_pipeline.py +++ b/solutions/sports/numbers/resnet_pipeline.py @@ -1,5 +1,15 @@ +import keras import torch +import torch.nn as nn +import torch.nn.functional as F import torchvision.transforms.v2.functional as tvf +from torch.utils.data import DataLoader +from torchmetrics import AUROC +from torchmetrics import Accuracy +from torchmetrics import F1Score +from torchmetrics import MetricCollection +from torchmetrics import Precision +from torchmetrics import Recall def transforms(image, target): @@ -19,3 +29,57 @@ def transforms(image, target): def collate_fn(batch): return tuple(zip(*batch)) + + +def validation_loop( + resnet: nn.Module, + val_dataloader: DataLoader, + device: torch.device, +): + criterion = nn.BCEWithLogitsLoss() + + 
metrics = MetricCollection(
+        [
+            Accuracy(task="binary"),
+            Precision(task="binary"),
+            Recall(task="binary"),
+            F1Score(task="binary"),
+            AUROC(task="binary"),
+        ]
+    ).to(device)
+
+    resnet.eval()
+
+    steps = 1
+    pbar = keras.utils.Progbar(len(val_dataloader))
+
+    with torch.no_grad():
+        for images, targets in val_dataloader:
+            images = torch.stack(images, dim=0).to(device)
+
+            targets = torch.stack(targets, dim=0)
+            targets_hot = F.one_hot(targets, num_classes=2)
+            targets_hot = targets_hot.to(device, dtype=torch.float32)
+            targets = targets.to(device, dtype=torch.float32)
+
+            outputs = resnet(images)
+            predictions = torch.argmax(outputs, dim=1)
+            loss = criterion(outputs, targets_hot)
+
+            metrics_dict = metrics.forward(predictions, targets)
+
+            pbar.update(
+                current=steps,
+                values=[("loss", loss.item()), *list(metrics_dict.items())],
+            )
+
+            steps += 1
+
+    resnet.train()
+
+    means = {}
+    for k, v in pbar._values.items():
+        acc = v[0].item() if torch.is_tensor(v[0]) else v[0]
+        means[k] = acc / max(1, v[1])
+
+    return means