Skip to content

Commit

Permalink
Merge pull request #518 from CharlesAuthier/feature/eta_tqdm
Browse files Browse the repository at this point in the history
adding total to tqdm
  • Loading branch information
CharlesAuthier committed Jun 13, 2023
2 parents 22e7f15 + 2bb5c93 commit 3d5d709
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ def segmentation(param,
fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes))
img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size)
single_class_mode = False if num_classes > 1 else True
for sub_image, h_idxs, w_idxs, hann_win in tqdm(img_gen, position=0, leave=True,
desc=f'Inferring on patches'):
for sub_image, h_idxs, w_idxs, hann_win in tqdm(
img_gen, position=0, leave=True, desc='Inferring on patches',
total=len(patch_list)
):
hann_win = np.expand_dims(hann_win, -1)
image_metadata = add_metadata_from_raster_to_sample(sat_img_arr=sub_image,
raster_handle=input_image,
Expand Down Expand Up @@ -285,7 +287,7 @@ def override_model_params_from_checkpoint(

if model_ckpt != params.model or classes_ckpt != classes or bands_ckpt != bands \
or clip_limit != clip_limit_ckpt:
logging.info(f"\nParameters from checkpoint will override inputted parameters."
logging.info("\nParameters from checkpoint will override inputted parameters."
f"\n\t\t\t Inputted | Overriden"
f"\nModel:\t\t {params.model} | {model_ckpt}"
f"\nInput bands:\t\t{bands} | {bands_ckpt}"
Expand Down
2 changes: 1 addition & 1 deletion tiling_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def main(cfg: DictConfig) -> None:
# TODO: use mp.Manager() to modify aoi.tiling_pairs_list from within tiling_per_aoi
tiler.src_aoi_list = []
for tiled_aoi, rs_tiler_paths, vec_tiler_paths in tqdm(
tilers, desc=f"Updating AOIs' information about their patches paths"):
tilers, desc="Updating AOIs' information about their patches paths"):
tiled_aoi.patches_pairs_list = [(rs_ptch, gt_ptch) for rs_ptch, gt_ptch in zip(rs_tiler_paths, vec_tiler_paths)]
tiler.src_aoi_list.append(tiled_aoi)

Expand Down

0 comments on commit 3d5d709

Please sign in to comment.