diff --git a/docs/source/recipes/torch-dataset-examples/ddp_train.py b/docs/source/recipes/torch-dataset-examples/ddp_train.py
new file mode 100644
index 0000000000..90a1716c5f
--- /dev/null
+++ b/docs/source/recipes/torch-dataset-examples/ddp_train.py
@@ -0,0 +1,249 @@
+from argparse import ArgumentParser
+import os
+
+import fiftyone as fo
+from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey
+
+import torch
+from tqdm import tqdm
+import numpy as np
+
+import utils
+
+
+def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
+
+    torch.distributed.init_process_group()
+
+    # setup local groups
+    local_group = None
+    for n in range(
+        int(
+            int(os.environ["WORLD_SIZE"]) / int(os.environ["LOCAL_WORLD_SIZE"])
+        )
+    ):
+        aux = torch.distributed.new_group()
+        torch.distributed.barrier()
+        if int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]) == n:
+            local_group = aux
+    local_broadcast_process_authkey(local_group)
+
+    model = utils.setup_ddp_model(num_classes=num_classes)
+    model.to(DEVICES[local_rank])
+    ddp_model = torch.nn.parallel.DistributedDataParallel(
+        model, device_ids=[DEVICES[local_rank]]
+    )
+
+    loss_function = torch.nn.CrossEntropyLoss(reduction="none")
+
+    dataset = None
+    # synchronously load dataset in each trainer
+    for r in range(int(os.environ["LOCAL_WORLD_SIZE"])):
+        if local_rank == r:
+            dataset = fo.load_dataset(dataset_name)
+        torch.distributed.barrier(local_group)
+
+    dataloaders = utils.create_dataloaders_ddp(
+        dataset,
+        utils.mnist_get_item,
+        local_process_group=local_group,
+        num_workers=4,
+        batch_size=16,
+        persistent_workers=True,
+    )
+    optimizer = utils.setup_optim(ddp_model)
+
+    best_epoch = None
+    best_loss = np.inf
+    for epoch in range(num_epochs):
+        train_epoch(
+            local_rank,
+            ddp_model,
+            dataloaders["train"],
+            loss_function,
+            optimizer,
+        )
+        validation_loss = validation(
+            local_rank,
+            ddp_model,
+            dataloaders["validation"],
+            dataset,
+            loss_function,
+        )
+
+        # average over all trainers
+        validation_loss = np.mean(all_gather(validation_loss))
+
+        if validation_loss < best_loss:
+            best_loss = validation_loss
+            best_epoch = epoch
+            if local_rank == 0:
+                print(f"New best loss achieved: {best_loss}. Saving model...")
+                torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch}.pt")
+
+        torch.distributed.barrier()
+
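+    # training is done; reload the best checkpoint and evaluate it on the
+    # test split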
+    model = utils.setup_ddp_model(
+        num_classes=num_classes,
+        weights_path=f"{save_dir}/epoch_{best_epoch}.pt",
+    ).to(DEVICES[local_rank])
+    model.to(DEVICES[local_rank])
+    ddp_model = torch.nn.parallel.DistributedDataParallel(
+        model, device_ids=[DEVICES[local_rank]]
+    )
+    test_loss = validation(
+        local_rank,
+        ddp_model,
+        dataloaders["test"],
+        dataset,
+        loss_function,
+        save_results=True,
+    )
+    test_loss = np.mean(all_gather(test_loss))
+    classes = [
+        utils.mnist_index_to_label_string(i) for i in range(num_classes)
+    ]
+    if local_rank == 0:
+        results = dataset.match_tags("test").evaluate_classifications(
+            "predictions",
+            gt_field="ground_truth",
+            eval_key="eval",
+            classes=classes,
+            k=3,
+        )
+
+        print("Final Test Results:")
+        print(f"Loss = {test_loss}")
+        results.print_report(classes=classes)
+
+    torch.distributed.destroy_process_group(torch.distributed.group.WORLD)
+
+
+def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
+    model.train()
+
+    cumulative_loss = 0
+    pbar = (
+        tqdm(enumerate(dataloader), total=len(dataloader))
+        if local_rank == 0
+        else enumerate(dataloader)
+    )
+    for batch_num, batch in pbar:
+        batch["image"] = batch["image"].to(DEVICES[local_rank])
+        batch["label"] = batch["label"].to(DEVICES[local_rank])
+
+        prediction = model(batch["image"])
+        loss = torch.mean(loss_function(prediction, batch["label"]))
+
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+        cumulative_loss = cumulative_loss + loss.detach().cpu().numpy()
+        if local_rank == 0:
+            if batch_num % 100 == 0:
+                pbar.set_description(
+                    f"Average Train Loss = {cumulative_loss / (batch_num + 1):10f}"
+                )
+    return cumulative_loss / (batch_num + 1)
+
+
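+# validation/test loop: accumulates the average loss across batches and, when
+# save_results=True, writes per-sample losses and predictions back to the
+# FiftyOne dataset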
+@torch.no_grad()
+def validation(
+    local_rank, model, dataloader, dataset, loss_function, save_results=False
+):
+    model.eval()
+
+    cumulative_loss = 0
+    pbar = (
+        tqdm(enumerate(dataloader), total=len(dataloader))
+        if local_rank == 0
+        else enumerate(dataloader)
+    )
+    for batch_num, batch in pbar:
+        with torch.no_grad():
+            batch["image"] = batch["image"].to(DEVICES[local_rank])
+            batch["label"] = batch["label"].to(DEVICES[local_rank])
+
+            prediction = model(batch["image"])
+            loss_individual = (
+                loss_function(prediction, batch["label"])
+                .detach()
+                .cpu()
+                .numpy()
+            )
+
+            if save_results:
+                samples = dataset._dataset.select(batch["id"])
+                samples.set_values("loss", loss_individual.tolist())
+
+                fo_predictions = [
+                    fo.Classification(
+                        label=utils.mnist_index_to_label_string(
+                            np.argmax(sample_logits)
+                        ),
+                        logits=sample_logits,
+                    )
+                    for sample_logits in prediction.detach().cpu().numpy()
+                ]
+                samples.set_values("predictions", fo_predictions)
+                samples.save()
+
+            cumulative_loss = cumulative_loss + np.mean(loss_individual)
+            if local_rank == 0:
+                if batch_num % 100 == 0:
+                    pbar.set_description(
+                        f"Average Validation Loss = {cumulative_loss / (batch_num + 1):10f}"
+                    )
+    return cumulative_loss / (batch_num + 1)
+
+
+if __name__ == "__main__":
+
+    # run with
+    # torchrun --nnodes=1 --nproc-per-node=6 \
+    # PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
+    # -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7
+
+    argparser = ArgumentParser()
+    argparser.add_argument(
+        "-d", "--dataset", type=str, help="name of fiftyone dataset"
+    )
+    argparser.add_argument(
+        "-n",
+        "--num_classes",
+        type=int,
+        help="number of classes in the dataset",
+    )
+    argparser.add_argument(
+        "-e",
+        "--epochs",
+        type=int,
+        help="number of epochs to train for",
+        default=5,
+    )
+    argparser.add_argument(
+        "-s",
+        "--save_dir",
+        type=str,
+        help="directory to save checkpoints to",
+        default="~/mnist_weights",
+    )
+    argparser.add_argument(
+        "--devices", default=range(torch.cuda.device_count()), nargs="*"
+    )
+
+    args = argparser.parse_args()
+
+    assert int(os.environ["LOCAL_WORLD_SIZE"]) == len(args.devices)
+
+    DEVICES = [torch.device(f"cuda:{d}") for d in args.devices]
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+
+    torch.multiprocessing.set_start_method("forkserver")
+    torch.multiprocessing.set_forkserver_preload(["torch", "fiftyone"])
+
+    main(
+        local_rank, args.dataset, args.num_classes, args.epochs, args.save_dir
+    )
diff --git a/docs/source/recipes/torch-dataset-examples/utils.py b/docs/source/recipes/torch-dataset-examples/utils.py
index 1648cde628..53bcfebc18 100644
--- a/docs/source/recipes/torch-dataset-examples/utils.py
+++ b/docs/source/recipes/torch-dataset-examples/utils.py
@@ -101,9 +101,7 @@ def mnist_index_to_label_string(index):
 
 
 convert_and_normalize = transforms.Compose(
-    [
-        transforms.ToTensor(),
-    ]
+    [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]
 )
 
 
@@ -142,7 +140,7 @@ def create_dataloaders(
 def setup_model(num_classes, weights_path=None):
     model = resnet18(weights=ResNet18_Weights.DEFAULT)
     linear_head = torch.nn.Linear(512, num_classes)
-    torch.nn.init.xavier_uniform(linear_head.weight)
+    torch.nn.init.xavier_uniform_(linear_head.weight)
     model.fc = linear_head
     if weights_path is not None:
         model.load_state_dict(torch.load(weights_path, weights_only=True))
@@ -154,6 +152,39 @@ def setup_optim(model, lr=0.01, l2=0.00001):
     return optimizer
 
 
+### DDP utils ###
+
+
+def setup_ddp_model(**kwargs):
+    model = setup_model(**kwargs)
+    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    return model
+
+
+def create_dataloaders_ddp(
+    dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
+):
+    split_tags = ["train", "validation", "test"]
+    dataloaders = {}
+    for split_tag in split_tags:
+        split = dataset.match_tags(split_tag).to_torch(
+            get_item,
+            cache_fields=cache_fields,
+            local_process_group=local_process_group,
+        )
+        shuffle = True if split_tag == "train" else False
+        dataloader = torch.utils.data.DataLoader(
+            split,
+            worker_init_fn=FiftyOneTorchDataset.worker_init,
+            sampler=torch.utils.data.DistributedSampler(
+                split, shuffle=shuffle
+            ),
+            **kwargs,
+        )
+        dataloaders[split_tag] = dataloader
+    return dataloaders
+
+
 if __name__ == "__main__":
     # this is just here to multiprocessing works when we call these functions in a notebook
     pass