
Add support for Sync BN (#423)
Summary:
Pull Request resolved: #423

Added support for synchronized batch normalization, using either PyTorch's implementation or Apex's.

Plugged the model complexity hook into `classy_train.py`. This helps test the bug I encountered and fixed, which requires the profiler together with sync batch norm to reproduce.
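
For orientation, here is a minimal sketch of how the new option could be switched on from a task config. It builds on the repo's own test helpers (`get_test_mlp_task_config` and `build_task`, both used in the test diffs below); the `batch_norm_sync_mode` key is the one read in `ClassificationTask.from_config`:

```python
# Minimal sketch (not part of this commit): enable sync BN through the config.
# Assumes the test helper module is importable from the repo root.
from classy_vision.tasks import build_task
from test.generic.config_utils import get_test_mlp_task_config

config = get_test_mlp_task_config()
config["batch_norm_sync_mode"] = "pytorch"  # "disabled" (default), "pytorch", or "apex"

task = build_task(config)  # from_config maps the string onto BatchNormSyncMode.PYTORCH
```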

Reviewed By: vreis

Differential Revision: D20307435

fbshipit-source-id: 30556effe0149a51e2f53970ffb422a6a8ef15be
mannatsingh authored and facebook-github-bot committed Mar 11, 2020
1 parent 7da612b commit 39e54ae
Showing 6 changed files with 75 additions and 8 deletions.
3 changes: 2 additions & 1 deletion classy_train.py
@@ -50,6 +50,7 @@
from classy_vision.hooks import (
CheckpointHook,
LossLrMeterLoggingHook,
ModelComplexityHook,
ProfilerHook,
ProgressBarHook,
TensorboardPlotHook,
@@ -118,7 +119,7 @@ def main(args, config):


def configure_hooks(args, config):
hooks = [LossLrMeterLoggingHook(args.log_freq)]
hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

# Make a folder to store checkpoints and tensorboard logging outputs
suffix = datetime.now().isoformat()
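As a usage note, a small sketch of how the hook list built by `configure_hooks` ends up on a task; `set_hooks` appears in the `ClassificationTask` diff below, and `ModelComplexityHook` reports the model's complexity statistics (parameters, FLOPs). The `log_freq` value here is a placeholder for `args.log_freq`:

```python
# Sketch: attaching the hooks from configure_hooks to a task, mirroring classy_train.py.
from classy_vision.hooks import LossLrMeterLoggingHook, ModelComplexityHook
from classy_vision.tasks import build_task
from test.generic.config_utils import get_test_mlp_task_config

log_freq = 100  # placeholder for args.log_freq
hooks = [LossLrMeterLoggingHook(log_freq), ModelComplexityHook()]

task = build_task(get_test_mlp_task_config())
task.set_hooks(hooks)  # the task calls each hook at the matching training events
```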
43 changes: 41 additions & 2 deletions classy_vision/tasks/classification_task.py
@@ -11,6 +11,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, Union

import torch
import torch.nn as nn
from classy_vision.dataset import ClassyDataset, build_dataset
from classy_vision.generic.distributed_util import (
all_reduce_mean,
@@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
BEFORE_EVAL = enum.auto()


class BatchNormSyncMode(enum.Enum):
DISABLED = enum.auto() # No Synchronized Batch Normalization
PYTORCH = enum.auto() # Use torch.nn.SyncBatchNorm
APEX = enum.auto() # Use apex.parallel.SyncBatchNorm, needs apex to be installed


class LastBatchInfo(NamedTuple):
loss: torch.Tensor
output: torch.Tensor
@@ -133,6 +140,7 @@ def __init__(self):
self.amp_opt_level = None
self.perf_log = []
self.last_batch = None
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED

def set_checkpoint(self, checkpoint):
"""Sets checkpoint on task.
@@ -204,14 +212,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
self.meters = meters
return self

def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode):
def set_distributed_options(
self,
broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED,
batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
):
"""Set distributed options.
Args:
broadcast_buffers_mode: Broadcast buffers mode. See
:class:`BroadcastBuffersMode` for options.
batch_norm_sync_mode: Batch normalization synchronization mode. See
:class:`BatchNormSyncMode` for options.
Raises:
RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
is not installed.
"""
self.broadcast_buffers_mode = broadcast_buffers_mode

if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
logging.info("Synchronized Batch Normalization is disabled")
else:
if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
raise RuntimeError("apex is not installed")
logging.info(
f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
)
self.batch_norm_sync_mode = batch_norm_sync_mode

return self

def set_hooks(self, hooks: List["ClassyHook"]):
@@ -317,7 +346,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_meters(meters)
.set_amp_opt_level(amp_opt_level)
.set_distributed_options(
BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
broadcast_buffers_mode=BroadcastBuffersMode[
config.get("broadcast_buffers", "disabled").upper()
],
batch_norm_sync_mode=BatchNormSyncMode[
config.get("batch_norm_sync_mode", "disabled").upper()
],
)
)
for phase_type in phase_types:
@@ -494,6 +528,11 @@ def prepare(
multiprocessing_context=dataloader_mp_context,
)

if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
self.base_model = apex.parallel.convert_syncbn_model(self.base_model)

# move the model and loss to the right device
if use_gpu:
self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
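To make the new `prepare` branch concrete, a standalone sketch of what the PyTorch path does to a model; the toy module is illustrative, and the Apex path is the analogous `apex.parallel.convert_syncbn_model(model)` call:

```python
import torch.nn as nn

# Stand-in for task.base_model; any model containing nn.BatchNorm*d layers qualifies.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

# BatchNormSyncMode.PYTORCH: replace every BatchNorm layer with nn.SyncBatchNorm.
# The conversion itself runs anywhere; statistics are only actually synchronized
# once the model runs under DistributedDataParallel in an initialized process group.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

print(model)  # the BatchNorm2d module now appears as SyncBatchNorm

# On the task side this path is selected with, e.g.:
#   task.set_distributed_options(batch_norm_sync_mode=BatchNormSyncMode.PYTORCH)
```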
5 changes: 3 additions & 2 deletions test/generic/config_utils.py
@@ -176,9 +176,9 @@ def get_test_mlp_task_config():
"num_classes": 2,
"crop_size": 20,
"class_ratio": 0.5,
"num_samples": 10,
"num_samples": 20,
"seed": 0,
"batchsize_per_replica": 3,
"batchsize_per_replica": 6,
"use_augmentation": False,
"use_shuffle": True,
"transforms": [
@@ -228,6 +228,7 @@ def get_test_mlp_task_config():
"input_dim": 1200,
"output_dim": 1000,
"hidden_dims": [10],
"use_batchnorm": True, # used for testing sync batchnorm
},
"meters": {"accuracy": {"topk": [1]}},
"optimizer": {
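The new `use_batchnorm` flag is what gives the sync-BN conversion something to rewrite; a sketch of checking that on the test MLP (this assumes the `mlp` model registered in `classy_vision.models` accepts the fields shown in the config above):

```python
import torch.nn as nn
from classy_vision.models import build_model

# Same fields as the test config above; use_batchnorm adds BatchNorm layers to the MLP.
model = build_model({
    "name": "mlp",
    "input_dim": 1200,
    "output_dim": 1000,
    "hidden_dims": [10],
    "use_batchnorm": True,
})

has_bn = any(
    isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) for m in model.modules()
)
print(has_bn)  # expected True, so convert_sync_batchnorm has layers to replace
```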
2 changes: 1 addition & 1 deletion test/hooks_loss_lr_meter_logging_hook_test.py
@@ -83,7 +83,7 @@ def scheduler_mock(where):
mock_lr_scheduler.update_interval = UpdateInterval.STEP
config = get_test_mlp_task_config()
config["num_epochs"] = 3
config["dataset"]["train"]["batchsize_per_replica"] = 5
config["dataset"]["train"]["batchsize_per_replica"] = 10
config["dataset"]["test"]["batchsize_per_replica"] = 5
task = build_task(config)
task.optimizer.param_schedulers["lr"] = mock_lr_scheduler
4 changes: 2 additions & 2 deletions test/manual/hooks_tensorboard_plot_hook_test.py
@@ -140,7 +140,7 @@ def flush(self):

config = get_test_mlp_task_config()
config["num_epochs"] = 3
config["dataset"]["train"]["batchsize_per_replica"] = 5
config["dataset"]["train"]["batchsize_per_replica"] = 10
config["dataset"]["test"]["batchsize_per_replica"] = 5
task = build_task(config)

@@ -152,7 +152,7 @@ def flush(self):
trainer = LocalTrainer()
trainer.train(task)

# We have 10 samples, batch size is 5. Each epoch is done in two steps.
# We have 20 samples, batch size is 10. Each epoch is done in two steps.
self.assertEqual(
writer.scalar_logs["train_learning_rate_updates"],
[0, 1 / 6, 2 / 6, 3 / 6, 4 / 6, 5 / 6],
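The expected values in that assertion are just step-count arithmetic; a short sketch using the numbers from the config above:

```python
# 20 train samples at a per-replica batch size of 10 gives 2 steps per epoch;
# over num_epochs = 3 that is 6 steps, so training progress ("where") advances
# by 1/6 per step and the hook logs one learning-rate point per step.
num_samples, batch_size, num_epochs = 20, 10, 3
steps_per_epoch = num_samples // batch_size        # 2
total_steps = steps_per_epoch * num_epochs         # 6
expected = [step / total_steps for step in range(total_steps)]
assert expected == [0, 1 / 6, 2 / 6, 3 / 6, 4 / 6, 5 / 6]
```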
26 changes: 26 additions & 0 deletions test/trainer_distributed_trainer_test.py
@@ -22,10 +22,13 @@ def setUp(self):
config = get_test_mlp_task_config()
invalid_config = copy.deepcopy(config)
invalid_config["name"] = "invalid_task"
sync_bn_config = copy.deepcopy(config)
sync_bn_config["sync_batch_norm_mode"] = "pytorch"
self.config_files = {}
for config_key, config in [
("config", config),
("invalid_config", invalid_config),
("sync_bn_config", sync_bn_config),
]:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
json.dump(config, f)
@@ -63,3 +66,26 @@ def test_training(self):
result = subprocess.run(cmd, shell=True)
success = result.returncode == 0
self.assertEqual(success, expected_success)

@unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
def test_sync_batch_norm(self):
"""Test that sync batch norm training doesn't hang."""

num_processes = 2
device = "gpu"

cmd = f"""{sys.executable} -m torch.distributed.launch \
--nnodes=1 \
--nproc_per_node={num_processes} \
--master_addr=localhost \
--master_port=29500 \
--use_env \
{self.path}/../classy_train.py \
--device={device} \
--config={self.config_files["sync_bn_config"]} \
--num_workers=4 \
--log_freq=100 \
--distributed_backend=ddp
"""
result = subprocess.run(cmd, shell=True)
self.assertEqual(result.returncode, 0)
