diff --git a/docs/predict/predict.md b/docs/predict/predict.md index 48df80a49a..d9c712f377 100644 --- a/docs/predict/predict.md +++ b/docs/predict/predict.md @@ -14,7 +14,7 @@ python predict.py \ --save_dir output/result ``` -Among them, `image_path` can also be a directory. At this time, all the images in the directory will be predicted and the visualization results will be saved. +Where `image_path` can be the path of a picture, a file list containing image paths, or a directory. At this time, that image or all images in file list or directory will be predicted and the visualization results will be saved. Similarly, you can use `--aug_pred` to turn on multi-scale flip prediction, and `--is_slide` to turn on sliding window prediction. @@ -35,10 +35,8 @@ Similarly, you can use `--aug_pred` to turn on multi-scale flip prediction, and ... ``` -* At this point, you can specify `image_list` as `train.txt` and `image_dir` as the directory where the training data is located when predicting. The robustness of PaddleSeg allows you to do this, and the output will be the prediction result of the **original training data**. - ## 2.API -Parameter Analysis of Forecast API +Parameter Analysis of Forecast API ``` paddleseg.core.predict( diff --git a/docs/predict/predict_cn.md b/docs/predict/predict_cn.md index 27d447afd4..815eb69ea1 100644 --- a/docs/predict/predict_cn.md +++ b/docs/predict/predict_cn.md @@ -13,7 +13,7 @@ python predict.py \ --save_dir output/result ``` -其中`image_path`也可以是一个目录,这时候将对目录内的所有图片进行预测并保存可视化结果图。 +其中`image_path`可以是一张图片的路径,也可以是一个包含图片路径的文件列表,也可以是一个目录,这时候将对该图片或文件列表或目录内的所有图片进行预测并保存可视化结果图。 同样的,可以通过`--aug_pred`开启多尺度翻转预测, `--is_slide`开启滑窗预测。 @@ -34,7 +34,6 @@ python predict.py \ ... ``` -* 此时你可以在预测时将`image_list`指定为`train.txt`,将`image_dir`指定为训练数据所在的目录。PaddleSeg的鲁棒性允许你这样做,输出的结果将是对**原始训练数据**的预测结果。 ## 2.预测函数API 预测API的参数解析 diff --git a/paddleseg/utils/utils.py b/paddleseg/utils/utils.py index 6b33442f49..8652c6a56e 100644 --- a/paddleseg/utils/utils.py +++ b/paddleseg/utils/utils.py @@ -41,6 +41,7 @@ def load_entire_model(model, pretrained): logger.warning('Not all pretrained params of {} are loaded, ' \ 'training from scratch or a pretrained backbone.'.format(model.__class__.__name__)) + def download_pretrained_model(pretrained_model): """ Download pretrained model from url. @@ -59,17 +60,16 @@ def download_pretrained_model(pretrained_model): savename = savename.split('.')[0] with generate_tempdir() as _dir: - with filelock.FileLock( - os.path.join(seg_env.TMP_HOME, savename)): + with filelock.FileLock(os.path.join(seg_env.TMP_HOME, savename)): pretrained_model = download_file_and_uncompress( pretrained_model, savepath=_dir, extrapath=seg_env.PRETRAINED_MODEL_HOME, extraname=savename) - pretrained_model = os.path.join(pretrained_model, - 'model.pdparams') + pretrained_model = os.path.join(pretrained_model, 'model.pdparams') return pretrained_model + def load_pretrained_model(model, pretrained_model): if pretrained_model is not None: logger.info('Loading pretrained model from {}'.format(pretrained_model)) @@ -165,10 +165,11 @@ def get_image_list(image_path): image_list.append(os.path.join(root, f)) else: raise FileNotFoundError( - '`--image_path` is not found. it should be an image file or a directory including images' + '`--image_path` is not found. it should be a path of image, or a file list containing image paths, or a directory including images.' ) if len(image_list) == 0: - raise RuntimeError('There are not image file in `--image_path`') + raise RuntimeError( + 'There are not image file in `--image_path`={}'.format(image_path)) return image_list, image_dir diff --git a/predict.py b/predict.py index 8835140a9c..f45724d2cb 100644 --- a/predict.py +++ b/predict.py @@ -38,7 +38,7 @@ def parse_args(): '--image_path', dest='image_path', help= - 'The path of image, it can be a file or a directory including images', + 'The image to predict, which can be a path of image, or a file list containing image paths, or a directory including images', type=str, default=None) parser.add_argument( @@ -107,7 +107,6 @@ def parse_args(): return parser.parse_args() - def get_test_config(cfg, args): test_config = cfg.test_config