[Feature] Add PyTorch Profiler to analyze training process #937

Merged: 6 commits, Apr 27, 2021

@zhouzaida zhouzaida commented Apr 11, 2021

The PyTorch Profiler collects performance metrics during training and inference. This PR adds a profiler hook so that it can be enabled from a runner. More details on the profiler are available at https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
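For orientation, the underlying `torch.profiler.profile` API can also be used directly, outside any runner. The following is a minimal sketch, assuming PyTorch 1.8.1 or later and CPU-only profiling; the tiny model and random inputs are placeholders and are not part of this PR.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# Placeholder model and data, chosen only to give the profiler something to record.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
inputs = torch.randn(8, 32)

# Record CPU operators (and their input shapes) for a few forward/backward passes.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(3):
        loss = model(inputs).sum()
        loss.backward()

# Aggregate the recorded events by operator name and show the most expensive ones.
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(summary)
```

The script below uses the runner hook from this PR rather than calling the profiler directly.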

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

from mmcv.parallel import MMDataParallel
from mmcv.runner import EpochBasedRunner
from mmcv.utils import get_logger


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
    def train_step(self, data, optimizer):  # optimizer is unused; the argument is reserved
        images, labels = data
        predicts = self(images)  # -> self.__call__() -> self.forward()
        loss = F.cross_entropy(predicts, labels)
        return {'loss': loss}


# instantiate a model
model = Model()
# wrap model by MMDataParallel
model = MMDataParallel(model.cuda())

# dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = CIFAR10(root='data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

# runner
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
runner = EpochBasedRunner(model, optimizer=optimizer, work_dir='./work_dir',
                          logger=get_logger('mmcv'), max_epochs=1)

# optimizer hook
optimizer_config = dict(grad_clip=None)
# learning rate scheduler hook
lr_config = dict(policy='step', step=[2, 3])
# checkpoint hook
checkpoint_config = dict(interval=1)
# log hook
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
runner.register_training_hooks(lr_config=lr_config,
                               optimizer_config=optimizer_config,
                               checkpoint_config=checkpoint_config,
                               log_config=log_config)
# profiler hook
trace_config = dict(type='tb_trace', dir_name='./work_dir')
profiler_config = dict(on_trace_ready=trace_config, record_shapes=True)
runner.register_profiler_hook(profiler_config)

# training
runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
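
The `on_trace_ready` entry above is a plain dict rather than a callable. One way such a dict could be resolved into a handler is sketched below. `build_trace_handler` and its stand-in handler are hypothetical and are not taken from this PR's code; in a real setup the `tb_trace` branch would presumably delegate to `torch.profiler.tensorboard_trace_handler(dir_name)`.

```python
def build_trace_handler(cfg):
    """Turn a config such as dict(type='tb_trace', dir_name=...) into a callable.

    Hypothetical helper for illustration only.
    """
    if callable(cfg):
        return cfg
    options = dict(cfg)  # copy so the caller's config is not mutated
    trace_type = options.pop('type')
    if trace_type == 'tb_trace':
        dir_name = options['dir_name']

        def handler(prof):
            # Stand-in: a real handler would export the trace into dir_name.
            return f'export {prof} to {dir_name}'

        return handler
    raise ValueError(f'unsupported trace type: {trace_type!r}')


trace_config = dict(type='tb_trace', dir_name='./work_dir')
handler = build_trace_handler(trace_config)
print(handler('fake-profile'))
```

Keeping the config as a dict lets it live in a config file alongside the other hook settings, at the cost of one dispatch step like the one above.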

After training, view the trace in TensorBoard:

> tensorboard --logdir=work_dir

  • num_workers=2

[Screenshots: TensorBoard profiler view at localhost:6006 (two images)]

Adjust num_workers according to the profiler's Performance Recommendation.

  • num_workers=4

[Screenshot: TensorBoard profiler view at localhost:6006]

  • num_workers=8

[Screenshot: TensorBoard profiler view]

As the screenshots show, increasing num_workers reduces the time spent in the DataLoader.
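
The effect can be illustrated without PyTorch. The sketch below is a simulation, not a DataLoader benchmark: `load_sample`, `time_epoch`, and the 20 ms delay are hypothetical stand-ins, and a thread pool plays the role of the DataLoader worker processes.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_sample(index, delay=0.02):
    """Pretend to read and decode one sample; the sleep stands in for I/O."""
    time.sleep(delay)
    return index


def time_epoch(num_workers, num_samples=16):
    """Return the wall-clock seconds needed to load every sample once."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        loaded = list(pool.map(load_sample, range(num_samples)))
    assert loaded == list(range(num_samples))
    return time.perf_counter() - start


# Same worker counts as in the screenshots above.
timings = {n: time_epoch(n) for n in (2, 4, 8)}
for workers, seconds in timings.items():
    print(f"workers={workers}: {seconds:.3f}s")
```

In real training the gain levels off once loading is no longer the bottleneck, which is why the profiler's recommendation is a better guide than simply maximizing the worker count.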

codecov bot commented Apr 11, 2021

Codecov Report

Merging #937 (e04c5b7) into master (0dd0c49) will decrease coverage by 1.20%.
The diff coverage is 14.94%.

@@            Coverage Diff             @@
##           master     #937      +/-   ##
==========================================
- Coverage   65.73%   64.53%   -1.21%     
==========================================
  Files         150      152       +2     
  Lines        9517     9761     +244     
  Branches     1726     1776      +50     
==========================================
+ Hits         6256     6299      +43     
- Misses       2938     3136     +198     
- Partials      323      326       +3     

| Flag | Coverage Δ |
| --- | --- |
| unittests | 64.53% <14.94%> (-1.21%) ⬇️ |

| Impacted Files | Coverage Δ |
| --- | --- |
| mmcv/runner/base_runner.py | 70.58% <12.50%> (-2.03%) ⬇️ |
| mmcv/runner/hooks/profiler.py | 14.10% <14.10%> (ø) |
| mmcv/runner/hooks/__init__.py | 100.00% <100.00%> (ø) |
| mmcv/runner/hooks/optimizer.py | 17.74% <0.00%> (-6.01%) ⬇️ |
| mmcv/runner/fp16_utils.py | 59.73% <0.00%> (-1.59%) ⬇️ |
| mmcv/utils/registry.py | 98.31% <0.00%> (-0.02%) ⬇️ |
| mmcv/cnn/builder.py | 100.00% <0.00%> (ø) |
| mmcv/ops/upfirdn2d.py | 15.29% <0.00%> (ø) |
| mmcv/cnn/bricks/transformer.py | 0.00% <0.00%> (ø) |
| mmcv/ops/fused_bias_leakyrelu.py | 30.90% <0.00%> (ø) |

... and 2 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@hellock hellock requested a review from xvjiarui April 17, 2021 02:15
@zhouzaida (Collaborator, Author) commented

ping @xvjiarui

@zhouzaida zhouzaida requested review from ZwwWayne and hellock April 24, 2021 05:54
@ZwwWayne ZwwWayne merged commit c142ece into open-mmlab:master Apr 27, 2021
@zhouzaida zhouzaida deleted the torch_profiler branch April 27, 2021 11:09