diff --git a/pathology/multiple_instance_learning/README.md b/pathology/multiple_instance_learning/README.md
new file mode 100644
index 0000000000..bca59bbe2c
--- /dev/null
+++ b/pathology/multiple_instance_learning/README.md
@@ -0,0 +1,123 @@
+
+# Multiple Instance Learning (MIL) Examples
+
+This tutorial contains a baseline method for Multiple Instance Learning (MIL) classification from Whole Slide Images (WSI).
+The dataset is from the [Prostate cANcer graDe Assessment (PANDA) Challenge - 2020](https://www.kaggle.com/c/prostate-cancer-grade-assessment/) for cancer grade classification from prostate histology WSIs.
+The implementation is based on:
+
+Andriy Myronenko, Ziyue Xu, Dong Yang, Holger Roth, Daguang Xu: "Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging". In MICCAI (2021). [arXiv](https://arxiv.org/abs/2111.01556)
+
+![mil_patches](./mil_patches.jpg)
+![mil_network](./mil_network.jpg)
+
+## Requirements
+
+The script is tested with:
+
+- `Ubuntu 18.04` | `Python 3.6` | `CUDA 11.0` | `PyTorch 1.10`
+
+- the default pipeline requires about 16 GB of memory per GPU
+
+- it was tested on a multi-GPU machine with 4x 16 GB GPUs
+
+## Dependencies and installation
+
+### MONAI
+
+Please install MONAI, plus the additional dependencies required for reading WSI TIFF files:
+
+```bash
+pip install tifffile
+pip install imagecodecs
+```
+
+For more information please check out [the installation guide](https://docs.monai.io/en/latest/installation.html).
+
+### Data
+
+The prostate biopsy WSI dataset can be downloaded from the Prostate cANcer graDe Assessment (PANDA) Challenge on [Kaggle](https://www.kaggle.com/c/prostate-cancer-grade-assessment/data).
+In this tutorial, we assume it is downloaded to the `/PandaChallenge2020` folder.
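+
+The data split is provided as a datalist JSON file in MONAI's decathlon-style format and is read with `load_decathlon_datalist`.
+For illustration, a minimal datalist might look like the following (file names and grades are placeholders; image paths are resolved relative to `--data_root`, and `label` holds the integer grade):
+
+```json
+{
+    "training": [
+        {"image": "case_0001.tiff", "label": 4},
+        {"image": "case_0002.tiff", "label": 0}
+    ],
+    "validation": [
+        {"image": "case_1001.tiff", "label": 1}
+    ]
+}
+```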
+
+## Examples
+
+Check all the available options:
+
+```bash
+python ./panda_mil_train_evaluate_pytorch_gpu.py -h
+```
+
+### Train
+
+Train in multi-GPU mode with AMP using all available GPUs,
+assuming the training images are in the `/PandaChallenge2020/train_images` folder.
+It will use the pre-defined 80/20 data split in [datalist_panda_0.json](https://drive.google.com/drive/u/0/folders/1CAHXDZqiIn5QUfg5A7XsK1BncRu6Ftbh)
+
+```bash
+python -u panda_mil_train_evaluate_pytorch_gpu.py \
+    --data_root=/PandaChallenge2020/train_images \
+    --amp \
+    --distributed \
+    --mil_mode=att_trans \
+    --batch_size=4 \
+    --epochs=50 \
+    --logdir=./logs
+```
+
+If you need to use only specific GPUs, simply add the `CUDA_VISIBLE_DEVICES=...` prefix:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u panda_mil_train_evaluate_pytorch_gpu.py \
+    --data_root=/PandaChallenge2020/train_images \
+    --amp \
+    --distributed \
+    --mil_mode=att_trans \
+    --batch_size=4 \
+    --epochs=50 \
+    --logdir=./logs
+```
+
+### Validation
+
+Run inference with the best checkpoint over the validation set:
+
+```bash
+# Validate checkpoint on a single gpu
+python -u panda_mil_train_evaluate_pytorch_gpu.py \
+    --data_root=/PandaChallenge2020/train_images \
+    --amp \
+    --mil_mode=att_trans \
+    --checkpoint=./logs/model.pt \
+    --validate
+```
+
+### Inference
+
+Run inference on a different dataset.
+It is the same script as for validation; we just specify a different data_root and dataset JSON list:
+
+```bash
+python -u panda_mil_train_evaluate_pytorch_gpu.py \
+    --data_root=/PandaChallenge2020/some_other_files \
+    --dataset_json=some_other_files.json \
+    --amp \
+    --mil_mode=att_trans \
+    --checkpoint=./logs/model.pt \
+    --validate
+```
+
+### Stats
+
+Expected train and validation loss curves
+
+![mil_train_loss](./mil_train_loss.png)
+![mil_val_loss](./mil_val_loss.png)
+
+Expected validation QWK metric
+
+![mil_val_qwk](./mil_val_qwk.png)
+
+## Questions and bugs
+
+- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
+- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
+- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
diff --git a/pathology/multiple_instance_learning/mil_network.jpg b/pathology/multiple_instance_learning/mil_network.jpg
new file mode 100644
index 0000000000..01c31c0434
Binary files /dev/null and b/pathology/multiple_instance_learning/mil_network.jpg differ
diff --git a/pathology/multiple_instance_learning/mil_patches.jpg b/pathology/multiple_instance_learning/mil_patches.jpg
new file mode 100644
index 0000000000..d4c24ca401
Binary files /dev/null and b/pathology/multiple_instance_learning/mil_patches.jpg differ
diff --git a/pathology/multiple_instance_learning/mil_train_loss.png b/pathology/multiple_instance_learning/mil_train_loss.png
new file mode 100644
index 0000000000..844f933dfd
Binary files /dev/null and b/pathology/multiple_instance_learning/mil_train_loss.png differ
diff --git a/pathology/multiple_instance_learning/mil_val_loss.png b/pathology/multiple_instance_learning/mil_val_loss.png
new file mode 100644
index 0000000000..4ff4c9aa3f
Binary files /dev/null and b/pathology/multiple_instance_learning/mil_val_loss.png differ
diff --git a/pathology/multiple_instance_learning/mil_val_qwk.png b/pathology/multiple_instance_learning/mil_val_qwk.png
new file mode 100644
index 0000000000..dc4b18b52c
Binary files /dev/null and b/pathology/multiple_instance_learning/mil_val_qwk.png differ
diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py
new file mode 100644
index 0000000000..c37bd5196d
--- /dev/null
+++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py
@@ -0,0 +1,544 @@
+import os
+import time
+import shutil
+import argparse
+import collections.abc
+import gdown
+
+import numpy as np
+from sklearn.metrics import cohen_kappa_score
+
+import torch
+import torch.nn as nn
+from torch.cuda.amp import GradScaler, autocast
+
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.dataloader import default_collate
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from monai.data import Dataset, load_decathlon_datalist
+from monai.data.image_reader import WSIReader
+from monai.metrics import Cumulative, CumulativeAverage
+from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord
+from monai.apps.pathology.transforms import TileOnGridd
+from monai.networks.nets import milmodel
+
+
+def parse_args():
+
+    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
+    parser.add_argument(
+        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
+    )
+    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
+
+    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
+    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
+    parser.add_argument(
+        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
+    )
+    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
+
+    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
+    parser.add_argument(
+        "--validate",
+        action="store_true",
+        help="run only inference on the validation set, must specify the checkpoint argument",
+    )
+
+    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
+
+    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
+    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
+
+    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
+    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
+    parser.add_argument(
+        "--val_every",
+        default=1,
+        type=int,
+        help="run validation after this number of epochs, default 1 to run every epoch",
+    )
+    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
+
+    ### for multigpu
+    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
+    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument(
+        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+
+    parser.add_argument(
+        "--quick", action="store_true", help="use a small subset of data for debugging"
+    )  # for debugging
+
+    args = parser.parse_args()
+
+    print("Argument values:")
+    for k, v in vars(args).items():
+        print(k, "=>", v)
+    print("-----------------")
+
+    return args
+
+
+def train_epoch(model, loader, optimizer, scaler, epoch, args):
+    """One train epoch over the dataset"""
+
+    model.train()
+    criterion = nn.BCEWithLogitsLoss()
+
+    run_loss = CumulativeAverage()
+    run_acc = CumulativeAverage()
+
+    start_time = time.time()
+    loss, acc = 0.0, 0.0
+
+    for idx, batch_data in enumerate(loader):
+
+        data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+
+        optimizer.zero_grad(set_to_none=True)
+
+        with autocast(enabled=args.amp):
+            logits = model(data)
+            loss = criterion(logits, target)
+
+        if args.amp:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
+
+        acc = (logits.sigmoid().sum(1).detach().round() == target.sum(1).round()).float().mean()
+
+        run_loss.append(loss)
+        run_acc.append(acc)
+
+        loss = run_loss.aggregate()
+        acc = run_acc.aggregate()
+
+        if args.rank == 0:
+            print(
+                "Epoch {}/{} {}/{}".format(epoch, args.epochs, idx, len(loader)),
+                "loss: {:.4f}".format(loss),
+                "acc: {:.4f}".format(acc),
+                "time {:.2f}s".format(time.time() - start_time),
+            )
+            start_time = time.time()
+
+    return loss, acc
+
+
+def val_epoch(model, loader, epoch, args, max_tiles=None):
+    """One validation epoch over the dataset"""
+
+    model.eval()
+
+    model2 = model if not args.distributed else model.module
+    has_extra_outputs = model2.mil_mode == "att_trans_pyramid"
+    extra_outputs = model2.extra_outputs
+    calc_head = model2.calc_head
+
+    criterion = nn.BCEWithLogitsLoss()
+
+    run_loss = CumulativeAverage()
+    run_acc = CumulativeAverage()
+    PREDS = Cumulative()
+    TARGETS = Cumulative()
+
+    start_time = time.time()
+    loss, acc = 0.0, 0.0
+
+    with torch.no_grad():
+
+        for idx, batch_data in enumerate(loader):
+
+            data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+
+            with autocast(enabled=args.amp):
+
+                if max_tiles is not None and data.shape[1] > max_tiles:
+                    # During validation, we want to use all instances/patches,
+                    # but if their number is very large, we may run out of GPU memory.
+                    # In this case, we first iteratively go over subsets of patches to calculate backbone features,
+                    # and only at the very end calculate the classification output.
+
+                    logits = []
+                    logits2 = []
+
+                    for i in range(int(np.ceil(data.shape[1] / float(max_tiles)))):
+                        data_slice = data[:, i * max_tiles : (i + 1) * max_tiles]
+                        logits_slice = model(data_slice, no_head=True)
+                        logits.append(logits_slice)
+
+                        if has_extra_outputs:
+                            logits2.append(
+                                [
+                                    extra_outputs["layer1"],
+                                    extra_outputs["layer2"],
+                                    extra_outputs["layer3"],
+                                    extra_outputs["layer4"],
+                                ]
+                            )
+
+                    logits = torch.cat(logits, dim=1)
+                    if has_extra_outputs:
+                        extra_outputs["layer1"] = torch.cat([l[0] for l in logits2], dim=0)
+                        extra_outputs["layer2"] = torch.cat([l[1] for l in logits2], dim=0)
+                        extra_outputs["layer3"] = torch.cat([l[2] for l in logits2], dim=0)
+                        extra_outputs["layer4"] = torch.cat([l[3] for l in logits2], dim=0)
+
+                    logits = calc_head(logits)
+
+                else:
+                    # if the number of instances is small, we can run inference directly
+                    logits = model(data)
+
+                loss = criterion(logits, target)
+
+            pred = logits.sigmoid().sum(1).detach().round()
+            target = target.sum(1).round()
+            acc = (pred == target).float().mean()
+
+            run_loss.append(loss)
+            run_acc.append(acc)
+            loss = run_loss.aggregate()
+            acc = run_acc.aggregate()
+
+            PREDS.extend(pred)
+            TARGETS.extend(target)
+
+            if args.rank == 0:
+                print(
+                    "Val epoch {}/{} {}/{}".format(epoch, args.epochs, idx, len(loader)),
+                    "loss: {:.4f}".format(loss),
+                    "acc: {:.4f}".format(acc),
+                    "time {:.2f}s".format(time.time() - start_time),
+                )
+                start_time = time.time()
+
+    # Calculate the QWK metric (Quadratic Weighted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
+    PREDS = PREDS.get_buffer().cpu().numpy()
+    TARGETS = TARGETS.get_buffer().cpu().numpy()
+    qwk = cohen_kappa_score(PREDS.astype(np.float64), TARGETS.astype(np.float64), weights="quadratic")
+
+    return loss, acc, qwk
+
+
+def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0):
+    """Save checkpoint"""
+
+    state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
+
+    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
+
+    filename = os.path.join(args.logdir, filename)
+    torch.save(save_dict, filename)
+    print("Saving checkpoint", filename)
+
+
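+# NOTE: the labels are treated as ordinal grades rather than one-hot classes:
+# LabelEncodeIntegerGraded below encodes grade g as g leading ones, e.g. for
+# num_classes=5, grade 3 -> (1,1,1,0,0); train_epoch/val_epoch then recover the
+# predicted grade as logits.sigmoid().sum(1).round().
+
+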
+class LabelEncodeIntegerGraded(Transform):
+    """
+    Convert an integer label to an encoded array representation of length num_classes,
+    with 1 filled in up to the label index, and 0 otherwise. For example, for num_classes=5,
+    the embedding of 2 is (1,1,0,0,0).
+
+    Args:
+        num_classes: the number of classes to convert to encoded format.
+        keys: keys of the corresponding items to be transformed.
+            Defaults to ``['label']``.
+
+    """
+
+    def __init__(self, num_classes, keys=["label"]):
+        super().__init__()
+        self.keys = keys
+        self.num_classes = num_classes
+
+    def __call__(self, data):
+
+        d = dict(data)
+        for key in self.keys:
+            label = int(d[key])
+
+            lz = np.zeros(self.num_classes, dtype=np.float32)
+            lz[:label] = 1.0
+            # alternative one-liner: lz = (np.arange(self.num_classes) < label).astype(np.float32)
+            d[key] = lz
+
+        return d
+
+
+def list_data_collate(batch: collections.abc.Sequence):
+    '''
+    Combine instances from a list of dicts into a single dict, by stacking them along the first dim
+        [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] -> {'image' : Nx3xHxW},
+    followed by the default collate which will form a batch BxNx3xHxW
+    '''
+
+    for i, item in enumerate(batch):
+        data = item[0]
+        data["image"] = torch.stack([ix["image"] for ix in item], dim=0)
+        batch[i] = data
+    return default_collate(batch)
+
+
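+# NOTE (illustration): with the tutorial defaults (batch_size=4, tile_count=44,
+# tile_size=256), a train_loader batch assembled by list_data_collate + the default
+# collate has batch["image"].shape == (4, 44, 3, 256, 256) and batch["label"].shape == (4, 5).
+
+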
+def main_worker(gpu, args):
+
+    args.gpu = gpu
+
+    if args.distributed:
+        args.rank = args.rank * torch.cuda.device_count() + gpu
+        dist.init_process_group(
+            backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+        )
+
+    print(args.rank, " gpu", args.gpu)
+
+    torch.cuda.set_device(args.gpu)  # use this default device (same as args.device if not distributed)
+    torch.backends.cudnn.benchmark = True
+
+    if args.rank == 0:
+        print("Batch size is:", args.batch_size, "epochs", args.epochs)
+
+    #############
+    # Create MONAI dataset
+    training_list = load_decathlon_datalist(
+        data_list_file_path=args.dataset_json,
+        data_list_key="training",
+        base_dir=args.data_root,
+    )
+    validation_list = load_decathlon_datalist(
+        data_list_file_path=args.dataset_json,
+        data_list_key="validation",
+        base_dir=args.data_root,
+    )
+
+    if args.quick:  # for debugging on a small subset
+        training_list = training_list[:16]
+        validation_list = validation_list[:16]
+
+    # read each WSI at pyramid level=1, extract tiles, augment, and rescale intensities
+    # (ScaleIntensityRangeD maps a_min=255 -> low / a_max=0 -> high, i.e. it inverts the
+    # image so the white background ends up at the low end of the range)
+    train_transform = Compose(
+        [
+            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
+            TileOnGridd(
+                keys=["image"],
+                tile_count=args.tile_count,
+                tile_size=args.tile_size,
+                random_offset=True,
+                background_val=255,
+                return_list_of_dicts=True,
+            ),
+            RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
+            RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
+            RandRotate90d(keys=["image"], prob=0.5),
+            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ToTensord(keys=["image", "label"]),
+        ]
+    )
+
+    valid_transform = Compose(
+        [
+            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
+            TileOnGridd(
+                keys=["image"],
+                tile_count=None,
+                tile_size=args.tile_size,
+                random_offset=False,
+                background_val=255,
+                return_list_of_dicts=True,
+            ),
+            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ToTensord(keys=["image", "label"]),
+        ]
+    )
+
+    dataset_train = Dataset(data=training_list, transform=train_transform)
+    dataset_valid = Dataset(data=validation_list, transform=valid_transform)
+
+    train_sampler = DistributedSampler(dataset_train) if args.distributed else None
+    val_sampler = DistributedSampler(dataset_valid, shuffle=False) if args.distributed else None
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset_train,
+        batch_size=args.batch_size,
+        shuffle=(train_sampler is None),
+        num_workers=args.workers,
+        pin_memory=True,
+        multiprocessing_context="spawn",
+        sampler=train_sampler,
+        collate_fn=list_data_collate,
+    )
+    valid_loader = torch.utils.data.DataLoader(
+        dataset_valid,
+        batch_size=1,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=True,
+        multiprocessing_context="spawn",
+        sampler=val_sampler,
+        collate_fn=list_data_collate,
+    )
+
+    if args.rank == 0:
+        print("Dataset training:", len(dataset_train), "validation:", len(dataset_valid))
+
+    model = milmodel.MILModel(num_classes=args.num_classes, pretrained=True, mil_mode=args.mil_mode)
+
+    best_acc = 0
+    start_epoch = 0
+    if args.checkpoint is not None:
+        checkpoint = torch.load(args.checkpoint, map_location="cpu")
+        model.load_state_dict(checkpoint["state_dict"])
+        if "epoch" in checkpoint:
+            start_epoch = checkpoint["epoch"]
+        if "best_acc" in checkpoint:
+            best_acc = checkpoint["best_acc"]
+        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))
+
+    model.cuda(args.gpu)
+
+    if args.distributed:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], output_device=args.gpu)
+
+    if args.validate:
+        # if we only want to validate an existing checkpoint
+        epoch_time = time.time()
+        val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=0, args=args, max_tiles=args.tile_count)
+        if args.rank == 0:
+            print(
+                "Final validation loss: {:.4f}".format(val_loss),
+                "acc: {:.4f}".format(val_acc),
+                "qwk: {:.4f}".format(qwk),
+                "time {:.2f}s".format(time.time() - epoch_time),
+            )
+
+        exit(0)
+
+    params = model.parameters()
+
+    if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
+        # use a smaller learning rate and larger weight decay for the transformer sub-network
+        m = model if not args.distributed else model.module
+        params = [
+            {"params": list(m.attention.parameters()) + list(m.myfc.parameters()) + list(m.net.parameters())},
+            {"params": list(m.transformer.parameters()), "lr": 6e-6, "weight_decay": 0.1},
+        ]
+
+    optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
+
+    if args.logdir is not None and args.rank == 0:
+        writer = SummaryWriter(log_dir=args.logdir)
+        if args.rank == 0:
+            print("Writing Tensorboard logs to", writer.log_dir)
+    else:
+        writer = None
+
+    ### RUN TRAINING
+    n_epochs = args.epochs
+    val_acc_max = 0.0
+
+    scaler = None
+    if args.amp:  # new native amp
+        scaler = GradScaler()
+
+    for epoch in range(start_epoch, n_epochs):
+
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+            torch.distributed.barrier()
+
+        print(args.rank, time.ctime(), "Epoch:", epoch)
+
+        epoch_time = time.time()
+        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args)
+
+        if args.rank == 0:
+            print(
+                "Final training {}/{}".format(epoch, n_epochs - 1),
+                "loss: {:.4f}".format(train_loss),
+                "acc: {:.4f}".format(train_acc),
+                "time {:.2f}s".format(time.time() - epoch_time),
+            )
+
+        if args.rank == 0 and writer is not None:
+            writer.add_scalar("train_loss", train_loss, epoch)
+            writer.add_scalar("train_acc", train_acc, epoch)
+
+        if args.distributed:
+            torch.distributed.barrier()
+
+        b_new_best = False
+        val_acc = 0
+        if (epoch + 1) % args.val_every == 0:
+
+            epoch_time = time.time()
+            val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=epoch, args=args, max_tiles=args.tile_count)
+            if args.rank == 0:
+                print(
+                    "Final validation {}/{}".format(epoch, n_epochs - 1),
+                    "loss: {:.4f}".format(val_loss),
+                    "acc: {:.4f}".format(val_acc),
+                    "qwk: {:.4f}".format(qwk),
+                    "time {:.2f}s".format(time.time() - epoch_time),
+                )
+                if writer is not None:
+                    writer.add_scalar("val_loss", val_loss, epoch)
+                    writer.add_scalar("val_acc", val_acc, epoch)
+                    writer.add_scalar("val_qwk", qwk, epoch)
+
+                val_acc = qwk
+
+                if val_acc > val_acc_max:
+                    print("qwk ({:.6f} --> {:.6f})".format(val_acc_max, val_acc))
+                    val_acc_max = val_acc
+                    b_new_best = True
+
+        if args.rank == 0 and args.logdir is not None:
+            save_checkpoint(model, epoch, args, best_acc=val_acc, filename="model_final.pt")
+            if b_new_best:
+                print("Copying to model.pt new best model!!!!")
+                shutil.copyfile(os.path.join(args.logdir, "model_final.pt"), os.path.join(args.logdir, "model.pt"))
+
+        scheduler.step()
+
+    print("ALL DONE")
+
+
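+def main():
+    # NOTE (editor's reconstruction): the captured diff calls main() below without
+    # defining it. This is a minimal sketch of the missing entry point, assuming the
+    # usual parse_args + mp.spawn pattern implied by main_worker above. The original
+    # presumably also used the imported gdown module to fetch the pre-defined datalist
+    # from the Google Drive folder linked in the README; that step is omitted here.
+    args = parse_args()
+
+    if args.dataset_json is None:
+        raise ValueError("--dataset_json is required (see the pre-defined split linked in the README)")
+
+    if args.distributed:
+        # one process per GPU; each process derives its global args.rank inside main_worker
+        ngpus_per_node = torch.cuda.device_count()
+        args.world_size = ngpus_per_node * args.world_size
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
+    else:
+        main_worker(0, args)
+
+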
+if __name__ == "__main__":
+    main()
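
As a quick sanity check of the `milmodel.MILModel` interface the script relies on, the sketch below (an editor's illustration, not part of the tutorial) runs a forward pass on random data; shapes follow the tutorial defaults of a bag of 44 tiles of 3x256x256 per WSI, and `pretrained=False` is assumed here only to skip the backbone weight download:

```python
import torch
from monai.networks.nets import milmodel

# batch of 2 WSIs, each represented as a bag of 44 tiles of 3x256x256
model = milmodel.MILModel(num_classes=5, pretrained=False, mil_mode="att_trans")
x = torch.rand(2, 44, 3, 256, 256)

with torch.no_grad():
    logits = model(x)  # instances are pooled by the attention/transformer head

print(logits.shape)  # expected: torch.Size([2, 5])
```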