Skip to content

Commit

Permalink
MIL component to extract patches (#3237)
Browse files Browse the repository at this point in the history
* MIL component to extract patches

Signed-off-by: myron <amyronenko@nvidia.com>

* MIL component to extract patches

Signed-off-by: myron <amyronenko@nvidia.com>

* random flag, minor fixes

Signed-off-by: myron <amyronenko@nvidia.com>

* minor fixes for padding

Signed-off-by: myron <amyronenko@nvidia.com>

* improve tests

Signed-off-by: myron <amyronenko@nvidia.com>

Co-authored-by: Behrooz <3968947+drbeh@users.noreply.github.com>
  • Loading branch information
myron and bhashemian authored Nov 15, 2021
1 parent 4d83fc0 commit 9ec1d14
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 11 deletions.
4 changes: 2 additions & 2 deletions monai/apps/pathology/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .spatial.array import SplitOnGrid
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
from .spatial.array import SplitOnGrid, TileOnGrid
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
from .stain.array import ExtractHEStains, NormalizeHEStains
from .stain.dictionary import (
ExtractHEStainsd,
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/pathology/transforms/spatial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .array import SplitOnGrid
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
from .array import SplitOnGrid, TileOnGrid
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
158 changes: 155 additions & 3 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union
from typing import Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.transforms.transform import Transform
from monai.transforms.transform import Randomizable, Transform

__all__ = ["SplitOnGrid"]
__all__ = ["SplitOnGrid", "TileOnGrid"]


class SplitOnGrid(Transform):
Expand Down Expand Up @@ -73,3 +75,153 @@ def get_params(self, image_size):
)

return patch_size, steps


class TileOnGrid(Randomizable, Transform):
    """
    Tile a 2D image into square patches on a regular grid and keep a subset of them.

    This transform works only with np.ndarray inputs for 2D images. The input shape is
    unpacked as ``(channels, height, width)`` and the result has shape
    ``(number_of_tiles, channels, tile_size, tile_size)``.

    Args:
        tile_count: number of tiles to return. If None, all tiles are extracted and then
            filtered against the background threshold according to ``filter_mode``
            (so the number of returned tiles varies from image to image).
            Defaults to ``None``.
        tile_size: size of the square tile
            Defaults to ``256``.
        step: step size between neighbouring tiles
            Defaults to ``None`` (same as tile_size, i.e. a non-overlapping grid)
        random_offset: Randomize position of the grid, instead of starting from the top-left corner
            Defaults to ``False``.
        pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset
            Defaults to ``min`` (which assumes background is high value)

    Raises:
        ValueError: when ``filter_mode`` is not one of "min", "max" or "random".
    """

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        self.step = step
        self.random_offset = random_offset
        self.pad_full = pad_full
        self.background_val = background_val
        self.filter_mode = filter_mode

        if self.step is None:
            self.step = self.tile_size  # non-overlapping grid

        # state refreshed by ``randomize`` on every call
        self.offset = (0, 0)
        self.random_idxs = np.array((0,))

        if self.filter_mode not in ["min", "max", "random"]:
            raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Draw the random grid offset and the random tile subset for an image of ``img_size``.

        Args:
            img_size: shape of the image, unpacked as ``(channels, height, width)``.

        The results are stored in ``self.offset`` and ``self.random_idxs``.
        """
        _, h, w = img_size
        tile_step = cast(int, self.step)

        self.offset = (0, 0)
        if self.random_offset:
            # the grid may only be shifted within the remainder that does not fill a whole tile
            pad_h = h % self.tile_size
            pad_w = w % self.tile_size
            self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0)
            h = h - self.offset[0]
            w = w - self.offset[1]

        if self.pad_full:
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            h = h + pad_h
            w = w + pad_w

        # number of grid positions per axis; clamped at zero so that an image smaller than one
        # tile yields "no tiles" instead of a negative count (two negatives would otherwise
        # multiply into a bogus positive total)
        h_n = max((h - self.tile_size + tile_step) // tile_step, 0)
        w_n = max((w - self.tile_size + tile_step) // tile_step, 0)
        tile_total = h_n * w_n

        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            self.random_idxs = np.array((0,))

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """
        Extract tiles from ``image``.

        Args:
            image: array whose shape is unpacked as ``(channels, height, width)``.

        Returns:
            Array of shape ``(number_of_tiles, channels, tile_size, tile_size)``. When
            ``tile_count`` is set the first dimension equals ``tile_count`` (missing tiles are
            filled with ``background_val``); otherwise it depends on the image content.
        """
        # add random offset
        self.randomize(img_size=image.shape)
        tile_step = cast(int, self.step)

        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            image = image[:, self.offset[0] :, self.offset[1] :]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            _, h, w = image.shape
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            image = np.pad(
                image,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.background_val,
            )

        # extract tiles as a strided view over the image, then flatten the grid axes
        tile = self.tile_size
        n_channels, height, width = image.shape
        c_stride, h_stride, w_stride = image.strides
        # clamp at zero: an image smaller than a tile has no valid grid position, and
        # as_strided rejects negative dimensions
        n_rows = max((height - tile) // tile_step + 1, 0)
        n_cols = max((width - tile) // tile_step + 1, 0)
        tiles_view = as_strided(
            image,
            shape=(n_rows, n_cols, n_channels, tile, tile),
            strides=(h_stride * tile_step, w_stride * tile_step, c_stride, h_stride, w_stride),
            writeable=False,
        )
        image = tiles_view.reshape(-1, n_channels, tile, tile)

        # if keeping all patches
        if self.tile_count is None:
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
            # the threshold is 99.9% of the intensity sum of an all-background tile; it scales
            # with the actual channel count instead of assuming 3 channels
            thresh = 0.999 * n_channels * self.background_val * tile * tile
            if self.filter_mode == "min":
                # default, keep non-background tiles (small values)
                image = image[image.sum(axis=(1, 2, 3)) < thresh]
            elif self.filter_mode == "max":
                image = image[image.sum(axis=(1, 2, 3)) >= thresh]
            # "random" mode keeps every tile when no tile_count is requested

        elif len(image) > self.tile_count:
            if self.filter_mode == "min":
                # default, keep non-background tiles (smallest values)
                idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count]
                image = image[idxs]
            elif self.filter_mode == "max":
                idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :]
                image = image[idxs]
            else:
                # random subset (more appropriate for WSIs without distinct background);
                # the indices were drawn in ``randomize`` for this same grid size
                image = image[self.random_idxs]

        elif len(image) < self.tile_count:
            # not enough tiles: append background-only tiles up to tile_count
            image = np.pad(
                image,
                [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]],
                constant_values=self.background_val,
            )

        return image
84 changes: 80 additions & 4 deletions monai/apps/pathology/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
import copy
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch

from monai.config import KeysCollection
from monai.transforms.transform import MapTransform
from monai.transforms.transform import MapTransform, Randomizable

from .array import SplitOnGrid
from .array import SplitOnGrid, TileOnGrid

__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]


class SplitOnGridd(MapTransform):
Expand Down Expand Up @@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
return d


class TileOnGridd(Randomizable, MapTransform):
    """
    Dictionary-based wrapper that applies :py:class:`TileOnGrid` to every selected key.

    Tile the 2D image into patches on a grid and maintain a subset of it.
    This transform works only with np.ndarray inputs for 2D images.

    Args:
        keys: keys of the corresponding items to be transformed.
        tile_count: number of tiles to extract, if None extracts all non-background tiles
            Defaults to ``None``.
        tile_size: size of the square tile
            Defaults to ``256``.
        step: step size
            Defaults to ``None`` (same as tile_size)
        random_offset: Randomize position of the grid, instead of starting from the top-left corner
            Defaults to ``False``.
        pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset
            Defaults to ``min`` (which assumes background is high value)
        allow_missing_keys: forwarded to ``MapTransform``; when True, keys absent from the
            input are skipped.
            Defaults to ``False``.
        return_list_of_dicts: if True, return one dictionary per tile instead of a single
            dictionary holding the stacked tiles.
            Defaults to ``False``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
        allow_missing_keys: bool = False,
        return_list_of_dicts: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)

        self.return_list_of_dicts = return_list_of_dicts
        # seed handed to the array transform; redrawn on every call by ``randomize``
        self.seed = None

        self.splitter = TileOnGrid(
            tile_count=tile_count,
            tile_size=tile_size,
            step=step,
            random_offset=random_offset,
            pad_full=pad_full,
            background_val=background_val,
            filter_mode=filter_mode,
        )

    def randomize(self, data: Any = None) -> None:
        """Draw a fresh seed that is shared by all keys of the next call."""
        self.seed = self.R.randint(10000)  # type: ignore

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]:
        """
        Tile every selected item of ``data``.

        Returns:
            A dictionary with the tiled arrays, or, when ``return_list_of_dicts`` is set, a list
            with one dictionary per tile. In the list form, items that were not tiled are
            deep-copied into every dictionary. If none of the configured keys is present
            (only possible with ``allow_missing_keys``), the list contains the untouched sample.
        """
        self.randomize()

        d = dict(data)
        tiled_keys = []
        for key in self.key_iterator(d):
            self.splitter.set_random_state(seed=self.seed)  # same random seed for all keys
            d[key] = self.splitter(d[key])
            tiled_keys.append(key)

        if not self.return_list_of_dicts:
            return d

        if not tiled_keys:
            # nothing was tiled, so there is no tile axis to split along
            return [d]

        # take the tile count from a key that was actually processed: the first configured key
        # may be absent when allow_missing_keys is enabled
        # NOTE(review): this assumes every tiled key yields the same number of tiles - confirm
        # for inputs where tile_count is None and the filter is content dependent
        n_tiles = len(d[tiled_keys[0]])
        tiled = set(tiled_keys)
        return [{k: v[i] if k in tiled else copy.deepcopy(v) for k, v in d.items()} for i in range(n_tiles)]


# Alternate names bound to the same dictionary-transform classes.
SplitOnGridDict = SplitOnGridD = SplitOnGridd
TileOnGridDict = TileOnGridD = TileOnGridd
Loading

0 comments on commit 9ec1d14

Please sign in to comment.