Error : Default process group is not initialized #23

Closed
rassabin opened this issue Jul 16, 2020 · 13 comments

@rassabin

rassabin commented Jul 16, 2020

Torch : 1.4.0

CUDA: 10.0

MMCV : 1.0.2

MMSEG: 0.5.0+1c3f547

small custom dataset

Config :

norm_cfg = dict(type='BN', requires_grad=True)

model = dict(
    type='CascadeEncoderDecoder',
    num_stages=2,
    pretrained='open-mmlab://msra/hrnetv2_w18',
    backbone=dict(
        type='HRNet',
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(18, 36)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(18, 36, 72)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(18, 36, 72, 144)))),
    decode_head=[
        dict(
            type='FCNHead',
            in_channels=[18, 36, 72, 144],
            channels=270,
            in_index=(0, 1, 2, 3),
            input_transform='resize_concat',
            kernel_size=1,
            num_convs=1,
            concat_input=False,
            dropout_ratio=-1,
            num_classes=8,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        dict(
            type='OCRHead',
            in_channels=[18, 36, 72, 144],
            in_index=(0, 1, 2, 3),
            input_transform='resize_concat',
            channels=512,
            ocr_channels=256,
            dropout_ratio=-1,
            num_classes=8,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    ])
train_cfg = dict()
test_cfg = dict(mode='whole')
dataset_type = 'Aircraft'
data_root = '/mmdetection_aircraft/data/segm/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(1024, 768), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(512, 384), cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size=(512, 384), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 768),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=5,
    workers_per_gpu=2,
    train=dict(
        type='Aircraft',
        data_root='/mmdetection_aircraft/data/segm/',
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='Resize', img_scale=(1024, 768), ratio_range=(0.5, 2.0)),
            dict(type='RandomCrop', crop_size=(512, 384), cat_max_ratio=0.75),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size=(512, 384), pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg'])
        ],
        split='train.txt'),
    val=dict(
        type='Aircraft',
        data_root='/mmdetection_aircraft/data/segm/',
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1024, 768),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ],
        split='val.txt'),
    test=dict(
        type='Aircraft',
        data_root='/mmdetection_aircraft/data/segm/',
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1024, 768),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ],
        split='val.txt'))
log_config = dict(
    interval=1, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = 'checkpoints/ocrnet_hr18_512x1024_40k_cityscapes_20200601_033320-401c5bdd.pth'
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
total_iters = 3
checkpoint_config = dict(by_epoch=False, interval=3)
evaluation = dict(interval=3, metric='mIoU')
work_dir = './work_dirs/tutorial'
seed = 0
gpu_ids = [0]

TRAIN MODEL :

model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.CLASSES = datasets[0].CLASSES
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

Full error description:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-16-fec2661e1f4c> in <module>
     16 mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
     17 train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
---> 18                 meta=dict())

~/mmsegmentation/mmseg/apis/train.py in train_segmentor(model, dataset, cfg, distributed, validate, timestamp, meta)
    104     elif cfg.load_from:
    105         runner.load_checkpoint(cfg.load_from)
--> 106     runner.run(data_loaders, cfg.workflow, cfg.total_iters)

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py in run(self, data_loaders, workflow, max_iters, **kwargs)
    117                     if mode == 'train' and self.iter >= max_iters:
    118                         return
--> 119                     iter_runner(iter_loaders[i], **kwargs)
    120 
    121         time.sleep(1)  # wait for some hooks like loggers to finish

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py in train(self, data_loader, **kwargs)
     53         self.call_hook('before_train_iter')
     54         data_batch = next(data_loader)
---> 55         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
     56         if not isinstance(outputs, dict):
     57             raise TypeError('model.train_step() must return a dict')

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py in train_step(self, *inputs, **kwargs)
     29 
     30         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
---> 31         return self.module.train_step(*inputs[0], **kwargs[0])
     32 
     33     def val_step(self, *inputs, **kwargs):

~/mmsegmentation/mmseg/models/segmentors/base.py in train_step(self, data_batch, optimizer, **kwargs)
    147                 averaging the logs.
    148         """
--> 149         losses = self.forward_train(**data_batch, **kwargs)
    150         loss, log_vars = self._parse_losses(losses)
    151 

~/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py in forward_train(self, img, img_metas, gt_semantic_seg)
    150         """
    151 
--> 152         x = self.extract_feat(img)
    153 
    154         losses = dict()

~/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py in extract_feat(self, img)
     76     def extract_feat(self, img):
     77         """Extract features from images."""
---> 78         x = self.backbone(img)
     79         if self.with_neck:
     80             x = self.neck(x)

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/mmsegmentation/mmseg/models/backbones/hrnet.py in forward(self, x)
    512 
    513         x = self.conv1(x)
--> 514         x = self.norm1(x)
    515         x = self.relu(x)
    516         x = self.conv2(x)

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py in forward(self, input)
    456             if self.process_group:
    457                 process_group = self.process_group
--> 458             world_size = torch.distributed.get_world_size(process_group)
    459             need_sync = world_size > 1
    460 

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py in get_world_size(group)
    584         return -1
    585 
--> 586     return _get_group_size(group)
    587 
    588 

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py in _get_group_size(group)
    200     """
    201     if group is GroupMember.WORLD:
--> 202         _check_default_pg()
    203         return _default_pg.size()
    204     if group not in _pg_group_ranks:

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py in _check_default_pg()
    191     """
    192     assert _default_pg is not None, \
--> 193         "Default process group is not initialized"
    194 
    195 

AssertionError: Default process group is not initialized
@xvjiarui
Collaborator

Hi @rassabin
In your config, norm_cfg in backbone and heads is SyncBN, which requires distributed training.
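
For a single-GPU, non-distributed run, one option is to swap every SyncBN entry for plain BN before building the model. Note that the top-level norm_cfg in the posted config is already BN, while the nested norm_cfg entries in the backbone and both heads still say SyncBN. The sketch below is only an illustration on plain dicts: replace_norm_type is a hypothetical helper, not part of mmcv or mmseg, and applying it to cfg.model is an assumption.

```python
import copy


def replace_norm_type(cfg_node, old='SyncBN', new='BN'):
    """Return a deep copy of a nested config with norm layer types swapped.

    Hypothetical helper for illustration; not an mmcv/mmseg API.
    """
    cfg_node = copy.deepcopy(cfg_node)

    def _walk(node):
        # Visit dicts and sequences recursively, rewriting matching 'type' keys.
        if isinstance(node, dict):
            if node.get('type') == old:
                node['type'] = new
            for value in node.values():
                _walk(value)
        elif isinstance(node, (list, tuple)):
            for item in node:
                _walk(item)

    _walk(cfg_node)
    return cfg_node


# Trimmed-down version of the model config from this issue.
model = dict(
    backbone=dict(
        type='HRNet', norm_cfg=dict(type='SyncBN', requires_grad=True)),
    decode_head=[
        dict(type='FCNHead', norm_cfg=dict(type='SyncBN', requires_grad=True)),
        dict(type='OCRHead', norm_cfg=dict(type='SyncBN', requires_grad=True)),
    ])

single_gpu_model = replace_norm_type(model)
print(single_gpu_model['backbone']['norm_cfg'])
# {'type': 'BN', 'requires_grad': True}
```

Because the helper works on a copy, the original config is left untouched and can still be used for distributed training.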

@rassabin
Author

Yep, that helps. But it is strange that we have to change the norm_cfg parameter for each head separately, in addition to the backbone.

@rassabin
Author

A new error occurs when trying to use the mask version of binary cross-entropy.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-15-fec2661e1f4c> in <module>
     16 mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
     17 train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
---> 18                 meta=dict())

~/mmsegmentation/mmseg/apis/train.py in train_segmentor(model, dataset, cfg, distributed, validate, timestamp, meta)
    104     elif cfg.load_from:
    105         runner.load_checkpoint(cfg.load_from)
--> 106     runner.run(data_loaders, cfg.workflow, cfg.total_iters)

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py in run(self, data_loaders, workflow, max_iters, **kwargs)
    117                     if mode == 'train' and self.iter >= max_iters:
    118                         return
--> 119                     iter_runner(iter_loaders[i], **kwargs)
    120 
    121         time.sleep(1)  # wait for some hooks like loggers to finish

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py in train(self, data_loader, **kwargs)
     53         self.call_hook('before_train_iter')
     54         data_batch = next(data_loader)
---> 55         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
     56         if not isinstance(outputs, dict):
     57             raise TypeError('model.train_step() must return a dict')

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py in train_step(self, *inputs, **kwargs)
     29 
     30         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
---> 31         return self.module.train_step(*inputs[0], **kwargs[0])
     32 
     33     def val_step(self, *inputs, **kwargs):

~/mmsegmentation/mmseg/models/segmentors/base.py in train_step(self, data_batch, optimizer, **kwargs)
    148         """
    149         data_batch['gt_semantic_seg'] = data_batch['gt_semantic_seg'][:,0,:].permute(0, 3, 1, 2)
--> 150         losses = self.forward_train(**data_batch, **kwargs)
    151         loss, log_vars = self._parse_losses(losses)
    152 

~/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py in forward_train(self, img, img_metas, gt_semantic_seg)
    155 
    156         loss_decode = self._decode_head_forward_train(x, img_metas,
--> 157                                                       gt_semantic_seg)
    158         losses.update(loss_decode)
    159 

~/mmsegmentation/mmseg/models/segmentors/cascade_encoder_decoder.py in _decode_head_forward_train(self, x, img_metas, gt_semantic_seg)
     84 
     85         loss_decode = self.decode_head[0].forward_train(
---> 86             x, img_metas, gt_semantic_seg, self.train_cfg)
     87 
     88         losses.update(add_prefix(loss_decode, 'decode_0'))

~/mmsegmentation/mmseg/models/decode_heads/decode_head.py in forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg)
    181         """
    182         seg_logits = self.forward(inputs)
--> 183         losses = self.losses(seg_logits, gt_semantic_seg)
    184         return losses
    185 

~/mmsegmentation/mmseg/models/decode_heads/decode_head.py in losses(self, seg_logit, seg_label)
    225             seg_label,
    226             weight=seg_weight,
--> 227             ignore_index=self.ignore_index)
    228         loss['acc_seg'] = accuracy(seg_logit, seg_label)
    229         return loss

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py in forward(self, cls_score, label, weight, avg_factor, reduction_override, **kwargs)
    175             class_weight=class_weight,
    176             reduction=reduction,
--> 177             avg_factor=avg_factor)
    178         return loss_cls

~/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py in mask_cross_entropy(pred, target, label, reduction, avg_factor, class_weight)
    114     pred_slice = pred[inds, label].squeeze(1)
    115     return F.binary_cross_entropy_with_logits(
--> 116         pred_slice, target, weight=class_weight, reduction='mean')[None]
    117 
    118 

~/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2122 
   2123     if not (target.size() == input.size()):
-> 2124         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2125 
   2126     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([3, 3, 512, 384])) must be the same as input size (torch.Size([3, 8, 512, 384]))

I understand that the cause is the different number of channels in the model output and in the input annotations, but I cannot find a way to load an 8-channel mask. I have a script that converts the RGB representation into an 8-class mask, but where should I put it? The default

class CustomDataset(Dataset)

takes only the path to the mask files as input, and the files are read in the "build_from_cfg" module of mmcv. Thanks. Any suggestions?

@xvjiarui
Collaborator

Hi @rassabin

  1. For the mask version of binary cross-entropy you mentioned, are you referring to use_mask=True?
  2. For the 8-channel mask, do you mean loading a segmentation map of shape (8, H, W)? Is it a binary mask of 8 classes?

@rassabin
Author

  1. Yes.
  2. I have the RGB masks in PNG files; applying the conversion script turns each one into an array of shape (8, H, W), so yes, a binary mask of 8 classes.
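
If the (8, H, W) binary masks already exist, one way to get a single-channel map is to collapse the class axis with an argmax. This is a minimal sketch, assuming each pixel belongs to at most one class; one_hot_to_label_map is a hypothetical helper, and using 255 for unlabeled pixels is an assumption that matches seg_pad_val=255 in the config above.

```python
import numpy as np


def one_hot_to_label_map(one_hot_mask, ignore_index=255):
    """Collapse a (num_classes, H, W) binary mask into an (H, W) index map.

    Hypothetical helper for illustration; assumes at most one class per pixel.
    """
    one_hot_mask = np.asarray(one_hot_mask)
    # Index of the active channel at every pixel.
    label_map = one_hot_mask.argmax(axis=0).astype(np.uint8)
    # Pixels where no channel is set get the ignore value.
    label_map[one_hot_mask.sum(axis=0) == 0] = ignore_index
    return label_map


# Tiny 3-class, 2x2 example; pixel (1, 1) is left unlabeled.
one_hot = np.zeros((3, 2, 2), dtype=np.uint8)
one_hot[0, 0, 0] = 1
one_hot[2, 0, 1] = 1
one_hot[1, 1, 0] = 1

print(one_hot_to_label_map(one_hot).tolist())
# [[0, 2], [1, 255]]
```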

@xvjiarui
Collaborator

Hi @rassabin
I suggest you convert the RGB masks into the same format as the standard datasets.
There are 8 colors in total in your dataset, so you can convert the masks into PNG files in P mode.
For example,

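import numpy as np
from PIL import Image

# RGB color list of length 8
color_list = [[0, 255, 0], ..., [255, 0, 255]]
palette = np.array(color_list, dtype=np.uint8)
# quantize() expects the palette as a `P` mode image, not an array
palette_image = Image.new('P', (1, 1))
palette_image.putpalette(palette.flatten().tolist())
# open the RGB mask
image = Image.open(img_path).convert('RGB')
# convert to `P` mode; dither=0 disables dithering so colors map to exact palette entries
new_image = image.quantize(palette=palette_image, dither=0)
new_image.save(new_img_path)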

@rassabin
Author

Creating a new representation of the image is not the problem. The problem is that the network has 8 output channels, and the mask binary cross-entropy compares the (8, H, W) output with a (3, H, W) label.

@xvjiarui
Collaborator

Hi @rassabin
For datasets like Cityscapes, the ground truth segmentation maps are of shape (H, W). Each pixel value ranges from 0 to num_classes-1.
So there is no need for the mask version of cross-entropy loss.
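
To make the (H, W) format concrete, here is a rough sketch of turning an RGB mask into such an index map with NumPy. rgb_to_label_map is a hypothetical helper, the two colors are only placeholders for the dataset's real 8-color list, and mapping unknown colors to 255 is an assumption that matches seg_pad_val=255 in the config above.

```python
import numpy as np


def rgb_to_label_map(rgb_mask, color_list, ignore_index=255):
    """Map an (H, W, 3) RGB mask to an (H, W) map of class indices.

    Hypothetical helper for illustration; colors not in color_list
    are assigned ignore_index.
    """
    rgb_mask = np.asarray(rgb_mask)
    label_map = np.full(rgb_mask.shape[:2], ignore_index, dtype=np.uint8)
    for class_index, color in enumerate(color_list):
        # True where all three channels equal this class color.
        matches = np.all(rgb_mask == np.asarray(color), axis=-1)
        label_map[matches] = class_index
    return label_map


# Two placeholder classes and a 2x2 mask; the last pixel has an unknown color.
example_colors = [[0, 255, 0], [255, 0, 255]]
example_mask = np.array([
    [[0, 255, 0], [255, 0, 255]],
    [[255, 0, 255], [1, 2, 3]],
], dtype=np.uint8)

example_labels = rgb_to_label_map(example_mask, example_colors)
print(example_labels.tolist())
# [[0, 1], [1, 255]]
```

The resulting array could then be saved as a single-channel PNG in the ann_dir used by the dataset config, so the standard cross-entropy loss sees labels in the range 0 to num_classes-1.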

@rassabin
Author

OK, I got it, thank you. By the way, does this mean the mask binary cross-entropy loss cannot be used at the moment?

@xvjiarui
Collaborator

Hi @rassabin
It is not used by any models yet. We reserved it for potential usage in the future.

@rubeea

rubeea commented Nov 21, 2020

Hi @rassabin
In your config, norm_cfg in backbone and heads is SyncBN, which requires distributed training.

Can you please specify how to solve this problem? Thanks in advance.

@NingAnMe

Change "SyncBN" to "BN" in "configs/_base_".

@675492062

For a single GPU, we removed this error by changing "SyncBN" to "BN".
