diff --git a/docs/source/recipes/torch-dataset-examples/ddp_train.py b/docs/source/recipes/torch-dataset-examples/ddp_train.py
index 90a1716c5f..97e45f3875 100644
--- a/docs/source/recipes/torch-dataset-examples/ddp_train.py
+++ b/docs/source/recipes/torch-dataset-examples/ddp_train.py
@@ -2,7 +2,7 @@
 import os
 
 import fiftyone as fo
-from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey
+from fiftyone.utils.torch import all_gather, FiftyOneTorchDataset
 
 import torch
 from tqdm import tqdm
@@ -15,18 +15,14 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
 
     torch.distributed.init_process_group()
 
-    # setup local groups
-    local_group = None
-    for n in range(
-        int(
-            int(os.environ["WORLD_SIZE"]) / int(os.environ["LOCAL_WORLD_SIZE"])
-        )
-    ):
-        aux = torch.distributed.new_group()
-        torch.distributed.barrier()
-        if int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]) == n:
-            local_group = aux
-    local_broadcast_process_authkey(local_group)
+    #### START FIFTYONE DISTRIBUTED INIT CODE ####
+    local_group = torch.distributed.new_group()
+    torch.distributed.barrier()
+
+    dataset = FiftyOneTorchDataset.distributed_init(
+        dataset_name, local_process_group=local_group
+    )
+    #### END FIFTYONE DISTRIBUTED INIT CODE ####
 
     model = utils.setup_ddp_model(num_classes=num_classes)
     model.to(DEVICES[local_rank])
@@ -36,13 +32,6 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
 
     loss_function = torch.nn.CrossEntropyLoss(reduction="none")
 
-    dataset = None
-    # synchronously load dataset in each trainer
-    for r in range(int(os.environ["LOCAL_WORLD_SIZE"])):
-        if local_rank == r:
-            dataset = fo.load_dataset(dataset_name)
-        torch.distributed.barrier(local_group)
-
     dataloaders = utils.create_dataloaders_ddp(
         dataset,
         utils.mnist_get_item,
@@ -116,7 +105,8 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
         print(f"Loss = {test_loss}")
         results.print_report(classes=classes)
 
-    torch.distributed.destroy_process_group(torch.distributed.group.WORLD)
+    torch.distributed.destroy_process_group(local_group)
+    torch.distributed.destroy_process_group()
 
 
 def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
@@ -200,10 +190,10 @@ def validation(
 
 if __name__ == "__main__":
 
-    # run with
-    # torchrun --nnodes=1 --nproc-per-node=6 \
-    #     PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
-    #     -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7
+    """run with
+    torchrun --nnodes=1 --nproc-per-node=6 \
+        PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
+        -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7"""
 
     argparser = ArgumentParser()
     argparser.add_argument(
diff --git a/docs/source/recipes/torch-dataset-examples/utils.py b/docs/source/recipes/torch-dataset-examples/utils.py
index c142c7a30c..5c2ff47856 100644
--- a/docs/source/recipes/torch-dataset-examples/utils.py
+++ b/docs/source/recipes/torch-dataset-examples/utils.py
@@ -120,14 +120,18 @@ def mnist_get_item(sample):
 
 
 def create_dataloaders(
-    dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
+    dataset,
+    get_item,
+    cache_field_names=None,
+    local_process_group=None,
+    **kwargs,
 ):
     split_tags = ["train", "validation", "test"]
     dataloaders = {}
     for split_tag in split_tags:
         split = dataset.match_tags(split_tag).to_torch(
             get_item,
-            cache_fields=cache_fields,
+            cache_field_names=cache_field_names,
             local_process_group=local_process_group,
         )
         shuffle = True if split_tag == "train" else False
@@ -166,14 +170,18 @@ def setup_ddp_model(**kwargs):
 
 
 def create_dataloaders_ddp(
-    dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
+    dataset,
+    get_item,
+    cache_field_names=None,
+    local_process_group=None,
+    **kwargs,
 ):
     split_tags = ["train", "validation", "test"]
     dataloaders = {}
     for split_tag in split_tags:
         split = dataset.match_tags(split_tag).to_torch(
             get_item,
-            cache_fields=cache_fields,
+            cache_field_names=cache_field_names,
             local_process_group=local_process_group,
         )
         shuffle = True if split_tag == "train" else False
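
Note for reviewers: the snippet below strings the changed pieces together as
they would appear in a standalone script, to make the new init flow easier to
verify. It is a minimal sketch, not part of the diff: the dataset name
"mnist", the cached field names, and the trivial get_item are assumptions
standing in for the recipe's utils.mnist_get_item and argparse wiring.

    # sketch_ddp_init.py -- run under torchrun, one process per device
    import torch
    from fiftyone.utils.torch import FiftyOneTorchDataset


    def get_item(sample):
        # stand-in for utils.mnist_get_item
        return sample["filepath"]


    torch.distributed.init_process_group()

    # auxiliary process group handed to FiftyOne; replaces the old
    # hand-rolled group construction + local_broadcast_process_authkey()
    local_group = torch.distributed.new_group()
    torch.distributed.barrier()

    # replaces the rank-staggered fo.load_dataset() loop: every rank gets
    # a handle to the same FiftyOne dataset
    dataset = FiftyOneTorchDataset.distributed_init(
        "mnist", local_process_group=local_group
    )

    # note the rename: cache_fields= is now cache_field_names=
    torch_dataset = dataset.match_tags("train").to_torch(
        get_item,
        cache_field_names=["filepath", "ground_truth"],
        local_process_group=local_group,
    )

    # ... training loop elided ...

    # tear down the local group before the default (global) one
    torch.distributed.destroy_process_group(local_group)
    torch.distributed.destroy_process_group()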