Skip to content

Commit

Permalink
[Doc] Polish predict doc and error message about image_path (Paddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
LutaoChu authored Nov 8, 2021
1 parent aa1c7d1 commit 0166716
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
6 changes: 2 additions & 4 deletions docs/predict/predict.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ python predict.py \
--save_dir output/result
```

Among them, `image_path` can also be a directory. At this time, all the images in the directory will be predicted and the visualization results will be saved.
Where `image_path` can be the path of a picture, a file list containing image paths, or a directory. At this time, that image or all images in file list or directory will be predicted and the visualization results will be saved.

Similarly, you can use `--aug_pred` to turn on multi-scale flip prediction, and `--is_slide` to turn on sliding window prediction.

Expand All @@ -35,10 +35,8 @@ Similarly, you can use `--aug_pred` to turn on multi-scale flip prediction, and
...
```
* At this point, you can specify `image_list` as `train.txt` and `image_dir` as the directory where the training data is located when predicting. The robustness of PaddleSeg allows you to do this, and the output will be the prediction result of the **original training data**.
## 2.API
Parameter Analysis of Forecast API
Parameter Analysis of Forecast API
```
paddleseg.core.predict(
Expand Down
3 changes: 1 addition & 2 deletions docs/predict/predict_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ python predict.py \
--save_dir output/result
```

其中`image_path`也可以是一个目录,这时候将对目录内的所有图片进行预测并保存可视化结果图
其中`image_path`可以是一张图片的路径,也可以是一个包含图片路径的文件列表,也可以是一个目录,这时候将对该图片或文件列表或目录内的所有图片进行预测并保存可视化结果图

同样的,可以通过`--aug_pred`开启多尺度翻转预测, `--is_slide`开启滑窗预测。

Expand All @@ -34,7 +34,6 @@ python predict.py \
...
```
* 此时你可以在预测时将`image_list`指定为`train.txt`,将`image_dir`指定为训练数据所在的目录。PaddleSeg的鲁棒性允许你这样做,输出的结果将是对**原始训练数据**的预测结果。
## 2.预测函数API
预测API的参数解析
Expand Down
13 changes: 7 additions & 6 deletions paddleseg/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def load_entire_model(model, pretrained):
logger.warning('Not all pretrained params of {} are loaded, ' \
'training from scratch or a pretrained backbone.'.format(model.__class__.__name__))


def download_pretrained_model(pretrained_model):
"""
Download pretrained model from url.
Expand All @@ -59,17 +60,16 @@ def download_pretrained_model(pretrained_model):
savename = savename.split('.')[0]

with generate_tempdir() as _dir:
with filelock.FileLock(
os.path.join(seg_env.TMP_HOME, savename)):
with filelock.FileLock(os.path.join(seg_env.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=seg_env.PRETRAINED_MODEL_HOME,
extraname=savename)
pretrained_model = os.path.join(pretrained_model,
'model.pdparams')
pretrained_model = os.path.join(pretrained_model, 'model.pdparams')
return pretrained_model


def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logger.info('Loading pretrained model from {}'.format(pretrained_model))
Expand Down Expand Up @@ -165,10 +165,11 @@ def get_image_list(image_path):
image_list.append(os.path.join(root, f))
else:
raise FileNotFoundError(
'`--image_path` is not found. it should be an image file or a directory including images'
'`--image_path` is not found. it should be a path of image, or a file list containing image paths, or a directory including images.'
)

if len(image_list) == 0:
raise RuntimeError('There are not image file in `--image_path`')
raise RuntimeError(
'There are not image file in `--image_path`={}'.format(image_path))

return image_list, image_dir
3 changes: 1 addition & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parse_args():
'--image_path',
dest='image_path',
help=
'The path of image, it can be a file or a directory including images',
'The image to predict, which can be a path of image, or a file list containing image paths, or a directory including images',
type=str,
default=None)
parser.add_argument(
Expand Down Expand Up @@ -107,7 +107,6 @@ def parse_args():
return parser.parse_args()



def get_test_config(cfg, args):

test_config = cfg.test_config
Expand Down

0 comments on commit 0166716

Please sign in to comment.