
Commit

updated ddp example to use distributed_init method and new cache_field_names argument
Jacob Sela committed Jan 16, 2025
1 parent 8a52fdf commit 4448027
Showing 2 changed files with 27 additions and 29 deletions.
docs/source/recipes/torch-dataset-examples/ddp_train.py (40 changes: 15 additions & 25 deletions)
@@ -2,7 +2,7 @@
import os

import fiftyone as fo
-from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey
+from fiftyone.utils.torch import all_gather, FiftyOneTorchDataset

import torch
from tqdm import tqdm
@@ -15,18 +15,14 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):

    torch.distributed.init_process_group()

-    # setup local groups
-    local_group = None
-    for n in range(
-        int(
-            int(os.environ["WORLD_SIZE"]) / int(os.environ["LOCAL_WORLD_SIZE"])
-        )
-    ):
-        aux = torch.distributed.new_group()
-        torch.distributed.barrier()
-        if int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]) == n:
-            local_group = aux
-    local_broadcast_process_authkey(local_group)
+    #### START FIFTYONE DISTRIBUTED INIT CODE ####
+    local_group = torch.distributed.new_group()
+    torch.distributed.barrier()
+
+    dataset = FiftyOneTorchDataset.distributed_init(
+        dataset_name, local_process_group=local_group
+    )
+    #### END FIFTYONE DISTRIBUTED INIT CODE ####

    model = utils.setup_ddp_model(num_classes=num_classes)
    model.to(DEVICES[local_rank])
@@ -36,13 +32,6 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):

    loss_function = torch.nn.CrossEntropyLoss(reduction="none")

-    dataset = None
-    # synchronously load dataset in each trainer
-    for r in range(int(os.environ["LOCAL_WORLD_SIZE"])):
-        if local_rank == r:
-            dataset = fo.load_dataset(dataset_name)
-        torch.distributed.barrier(local_group)
-
    dataloaders = utils.create_dataloaders_ddp(
        dataset,
        utils.mnist_get_item,
@@ -116,7 +105,8 @@ def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
    print(f"Loss = {test_loss}")
    results.print_report(classes=classes)

-    torch.distributed.destroy_process_group(torch.distributed.group.WORLD)
+    torch.distributed.destroy_process_group(local_group)
+    torch.distributed.destroy_process_group()


def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
@@ -200,10 +190,10 @@ def validation(

if __name__ == "__main__":

-    # run with
-    # torchrun --nnodes=1 --nproc-per-node=6 \
-    # PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
-    # -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7
+    """run with
+    torchrun --nnodes=1 --nproc-per-node=6 \
+    PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
+    -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7"""

    argparser = ArgumentParser()
    argparser.add_argument(
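Taken together, the ddp_train.py changes reduce to the pattern below. This is a minimal sketch, not the full example: the bare main() signature, the "mnist" dataset name, and the elided training steps are placeholders; only the process-group handling, the FiftyOneTorchDataset.distributed_init call, and the teardown order mirror the diff above.

    # sketch of the new pattern; assumes a torchrun launch, which sets
    # the env vars (RANK, WORLD_SIZE, ...) that init_process_group() reads
    import torch
    from fiftyone.utils.torch import FiftyOneTorchDataset

    def main(dataset_name):
        # join the default (global) process group
        torch.distributed.init_process_group()

        # auxiliary group FiftyOne uses to coordinate ranks during init
        local_group = torch.distributed.new_group()
        torch.distributed.barrier()

        # one call now replaces the old manual per-rank fo.load_dataset()
        # loop and the local_broadcast_process_authkey() step
        dataset = FiftyOneTorchDataset.distributed_init(
            dataset_name, local_process_group=local_group
        )

        # ... build dataloaders from `dataset`, train, evaluate ...

        # tear down the auxiliary group before the default one
        torch.distributed.destroy_process_group(local_group)
        torch.distributed.destroy_process_group()

    if __name__ == "__main__":
        main("mnist")  # placeholder dataset name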
docs/source/recipes/torch-dataset-examples/utils.py (16 changes: 12 additions & 4 deletions)
@@ -120,14 +120,18 @@ def mnist_get_item(sample):


def create_dataloaders(
-    dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
+    dataset,
+    get_item,
+    cache_field_names=None,
+    local_process_group=None,
+    **kwargs,
):
    split_tags = ["train", "validation", "test"]
    dataloaders = {}
    for split_tag in split_tags:
        split = dataset.match_tags(split_tag).to_torch(
            get_item,
-            cache_fields=cache_fields,
+            cache_field_names=cache_field_names,
            local_process_group=local_process_group,
        )
        shuffle = True if split_tag == "train" else False
@@ -166,14 +170,18 @@ def setup_ddp_model(**kwargs):


def create_dataloaders_ddp(
-    dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
+    dataset,
+    get_item,
+    cache_field_names=None,
+    local_process_group=None,
+    **kwargs,
):
    split_tags = ["train", "validation", "test"]
    dataloaders = {}
    for split_tag in split_tags:
        split = dataset.match_tags(split_tag).to_torch(
            get_item,
-            cache_fields=cache_fields,
+            cache_field_names=cache_field_names,
            local_process_group=local_process_group,
        )
        shuffle = True if split_tag == "train" else False
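In utils.py the only functional change is the keyword rename, cache_fields to cache_field_names, threaded through to to_torch(). A hedged usage sketch of the patched helper follows; the field names and the batch_size kwarg are illustrative assumptions, not part of this commit, and **kwargs is presumed to be forwarded to the underlying DataLoader as the signature suggests.

    # sketch: calling the patched helper with the renamed argument
    import fiftyone as fo

    from utils import create_dataloaders, mnist_get_item

    dataset = fo.load_dataset("mnist")  # placeholder dataset name
    dataloaders = create_dataloaders(
        dataset,
        mnist_get_item,
        cache_field_names=["filepath", "ground_truth"],  # was cache_fields=...
        batch_size=16,  # illustrative; passed through **kwargs
    )
    train_loader = dataloaders["train"]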
