Skip to content

Commit

Permalink
Fixed slide inference (open-mmlab#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
xvjiarui committed Aug 25, 2020
1 parent 9e63ddb commit 8c0e093
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
python tools/train.py ${CONFIG_FILE} [optional arguments]
```

If you want to specify the working directory in the command, you can add an argument `--work_dir ${YOUR_WORK_DIR}`.
If you want to specify the working directory in the command, you can add an argument `--work-dir ${YOUR_WORK_DIR}`.

### Train with multiple GPUs

Expand All @@ -253,7 +253,7 @@ Difference between `resume-from` and `load-from`:
If you run MMSegmentation on a cluster managed with [slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`. (This script also supports single machine training.)

```shell
[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} --work-dir ${WORK_DIR}
```

Here is an example of using 16 GPUs to train PSPNet on the dev partition.
Expand Down
2 changes: 1 addition & 1 deletion mmseg/apis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=not show, **data)
result = model(return_loss=False, **data)
if isinstance(results, list):
results.extend(result)
else:
Expand Down
2 changes: 1 addition & 1 deletion mmseg/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
except ImportError:
raise ImportError('Please run "pip install citscapesscripts" to '
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
msg = 'Evaluating in Cityscapes style'
if logger is None:
Expand Down
6 changes: 4 additions & 2 deletions mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,10 @@ def slide_inference(self, img, img_meta, rescale):

count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
# We want to regard count_mat as a constant while exporting to ONNX
count_mat = torch.from_numpy(count_mat.detach().numpy())
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(
count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
if rescale:
preds = resize(
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work_dir', help='the dir to save logs and models')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--load-from', help='the checkpoint file to load weights from')
parser.add_argument(
Expand Down

0 comments on commit 8c0e093

Please sign in to comment.