From fa1c1af79ef5387434f2a76744f75b5aaca09f0b Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:16:29 +0800 Subject: [PATCH] Fix RandomWeightedCrop for Integer Weightmap Handling (#8097) Fixes #7949 . ### Description Regardless of the type of `weight map`, random numbers should be kept as floating-point numbers for calculating the sampling location. However, `searchsorted` requires matching data structures. I have modified `convert_to_dst_type` to control converting only the data structure while maintaining the original data type. Additionally, I have included an example with integer weight maps in the test file. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han123su Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/transforms/utils.py | 3 ++- tests/test_rand_weighted_crop.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 32fffc25f0..e7e1616e13 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -582,7 +582,8 @@ def weighted_patch_samples( if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) else: - r, *_ = convert_to_dst_type(r_state.random(n_samples), v) + r_samples = r_state.random(n_samples) + r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype) idx = searchsorted(v, r * v[-1], right=True) # type: ignore idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore # compensate 'valid' mode diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 47a8f3bfa2..f509065a56 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -90,6 +90,21 @@ def get_data(ndim): [[63, 37], [31, 43], [66, 20]], ] ) + im = SEG1_2D + weight_map = np.zeros_like(im, dtype=np.int32) + weight_map[0, 30, 20] = 3 + weight_map[0, 45, 44] = 1 + weight_map[0, 60, 50] = 2 + TESTS.append( + [ + "int w 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), + q(weight_map), + (1, 10, 12), + [[60, 50], [30, 20], [45, 44]], + ] + ) im = SEG1_3D weight = np.zeros_like(im) weight[0, 5, 30, 17] = 1.1 @@ -149,6 +164,21 @@ def get_data(ndim): [[32, 24, 40], [32, 24, 40], [32, 24, 40]], ] ) + im = SEG1_3D + weight_map = np.zeros_like(im, dtype=np.int32) + weight_map[0, 6, 22, 19] = 4 + weight_map[0, 8, 40, 31] = 2 + weight_map[0, 13, 20, 24] = 3 + TESTS.append( + [ + "int w 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + q(weight_map), + (1, 8, 10, 12), + [[13, 20, 24], [6, 22, 19], [8, 40, 31]], + ] + ) class TestRandWeightedCrop(CropTest):