diff --git a/lazyslide/cv_mods/__init__.py b/lazyslide/cv_mods/__init__.py index db73628..84cbf1c 100644 --- a/lazyslide/cv_mods/__init__.py +++ b/lazyslide/cv_mods/__init__.py @@ -1,5 +1,11 @@ -from .mods import (ConvertColorspace, - MedianBlur, GaussianBlur, BoxBlur, - BinaryThreshold, MorphOpen, MorphClose, - ForegroundDetection, TissueDetectionHE - ) +from .mods import ( + ConvertColorspace, + MedianBlur, + GaussianBlur, + BoxBlur, + BinaryThreshold, + MorphOpen, + MorphClose, + ForegroundDetection, + TissueDetectionHE, +) diff --git a/lazyslide/cv_mods/mods.py b/lazyslide/cv_mods/mods.py index 227ee3c..e00c53b 100644 --- a/lazyslide/cv_mods/mods.py +++ b/lazyslide/cv_mods/mods.py @@ -6,8 +6,12 @@ class ConvertColorspace(Transform): - - def __init__(self, code=None, old=None, new=None, ): + def __init__( + self, + code=None, + old=None, + new=None, + ): if code is None: self.old = old.upper() self.new = new.upper() @@ -123,7 +127,7 @@ def __repr__(self): def apply(self, image): assert image.dtype == np.uint8, f"image dtype {image.dtype} must be np.uint8" assert ( - image.ndim == 2 + image.ndim == 2 ), f"input image has shape {image.shape}. Must convert to 1-channel image (H, W)." _, out = cv2.threshold( src=image, @@ -219,11 +223,11 @@ class ForegroundDetection(Transform): """ def __init__( - self, - mask_name=None, - min_region_size=5000, - max_hole_size=1500, - outer_contours_only=False, + self, + mask_name=None, + min_region_size=5000, + max_hole_size=1500, + outer_contours_only=False, ): self.min_region_size = min_region_size self.max_hole_size = max_hole_size @@ -286,12 +290,12 @@ def apply(self, mask): # loop thru contours for ( - cnt, - outside, - size_thresh, - hole, - hole_size_thresh, - hole_parent_thresh, + cnt, + outside, + size_thresh, + hole, + hole_size_thresh, + hole_parent_thresh, ) in zip( contours, outside_contours, @@ -331,11 +335,11 @@ class ForegroundContourDetection(Transform): """ def __init__( - self, - mask_name=None, - min_region_size=5000, - max_hole_size=1500, - outer_contours_only=False, + self, + mask_name=None, + min_region_size=5000, + max_hole_size=1500, + outer_contours_only=False, ): self.min_region_size = min_region_size self.max_hole_size = max_hole_size @@ -374,16 +378,22 @@ def apply(self, mask): # outside contours must be above min_tissue_region_size threshold tissue_contours = outmost[ - contours_areas[outmost_slice] > self.min_region_size] + contours_areas[outmost_slice] > self.min_region_size + ] tissue_holes = holes[ # hole contours must be above area threshold - contours_areas[hole_slice] < self.max_hole_size & \ + (contours_areas[hole_slice] < self.max_hole_size) # holes must have parents above area threshold - (contours_areas[hierarchy[hole_slice, 3]] > self.min_region_size)] + & (contours_areas[hierarchy[hole_slice, 3]] > self.min_region_size) + ] - return ([np.squeeze(contours[ix], axis=1) for ix in tissue_contours], - [np.squeeze(contours[ix], axis=1) for ix in tissue_holes]) + return ( + [np.squeeze(contours[ix], axis=1) for ix in tissue_contours], + [np.squeeze(contours[ix], axis=1) for ix in tissue_holes], + ) class TissueDetectionHE(Transform): @@ -406,16 +416,16 @@ class TissueDetectionHE(Transform): """ def __init__( - self, - use_saturation=True, - blur_ksize=17, - threshold=7, - morph_n_iter=3, - morph_k_size=7, - min_region_size=2500, - max_hole_size=100, - outer_contours_only=False, - return_contours=False, + self, + use_saturation=True, + blur_ksize=17, + threshold=7, + morph_n_iter=3, + morph_k_size=7, + 
min_region_size=2500, + max_hole_size=100, + outer_contours_only=False, + return_contours=False, ): self.use_sat = use_saturation self.blur_ksize = blur_ksize @@ -447,12 +457,8 @@ def __init__( self.pipeline = [ MedianBlur(kernel_size=self.blur_ksize), thresholder, - MorphOpen( - kernel_size=self.morph_k_size, - n_iterations=self.morph_n_iter), - MorphClose( - kernel_size=self.morph_k_size, - n_iterations=self.morph_n_iter), + MorphOpen(kernel_size=self.morph_k_size, n_iterations=self.morph_n_iter), + MorphClose(kernel_size=self.morph_k_size, n_iterations=self.morph_n_iter), foreground, ] @@ -466,7 +472,7 @@ def __repr__(self): def apply(self, image): assert ( - image.dtype == np.uint8 + image.dtype == np.uint8 ), f"Input image dtype {image.dtype} must be np.uint8" # first get single channel image_ref if self.use_sat: diff --git a/lazyslide/h5.py b/lazyslide/h5.py index 52ccac9..f7b9962 100644 --- a/lazyslide/h5.py +++ b/lazyslide/h5.py @@ -60,7 +60,7 @@ def load(self): for mask_name in masks_group.keys(): ds = masks_group.get(mask_name) masks[mask_name] = ds[:] - masks_level[mask_name] = ds.attrs['level'] + masks_level[mask_name] = ds.attrs["level"] self.masks = masks self.masks_level = masks_level @@ -115,8 +115,7 @@ def save(self): if self.COORDS_KEY in h5: del h5[self.COORDS_KEY] - ds = h5.create_dataset(self.COORDS_KEY, data=self.coords, - chunks=True) + ds = h5.create_dataset(self.COORDS_KEY, data=self.coords, chunks=True) attrs = ds.attrs for k, v in asdict(self.tile_ops).items(): if v is None: @@ -130,9 +129,11 @@ def save(self): masks_group = h5.create_group(self.MASKS_KEY) for mask_name, mask_array in self.masks.items(): - ds = masks_group.create_dataset(mask_name, data=mask_array, chunks=True) + ds = masks_group.create_dataset( + mask_name, data=mask_array, chunks=True + ) attrs = ds.attrs - attrs['level'] = self.masks_level[mask_name] + attrs["level"] = self.masks_level[mask_name] if self._rewrite_contours: if self.CONTOURS_KEY in h5: diff --git a/lazyslide/loader/dataset.py b/lazyslide/loader/dataset.py index 0b15eb9..8871930 100644 --- a/lazyslide/loader/dataset.py +++ b/lazyslide/loader/dataset.py @@ -1,17 +1,16 @@ import torch from torch.utils.data import Dataset -from torchvision.transforms.v2 import (ToImage, ToDtype, Normalize, - Compose, Resize) +from torchvision.transforms.v2 import ToImage, ToDtype, Normalize, Compose, Resize from lazyslide.normalizer import ColorNormalizer -def compose_transform(resize=None, - antialias=False, - color_normalize=None, - feature_extraction=False, - - ): +def compose_transform( + resize=None, + antialias=False, + color_normalize=None, + feature_extraction=False, +): if feature_extraction: mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) @@ -19,7 +18,11 @@ def compose_transform(resize=None, mean = (0.5, 0.5, 0.5) std = (0.5, 0.5, 0.5) pre = [] - after = [ToImage(), ToDtype(dtype=torch.float32, scale=True), Normalize(mean=mean, std=std)] + after = [ + ToImage(), + ToDtype(dtype=torch.float32, scale=True), + Normalize(mean=mean, std=std), + ] if resize is not None: pre += [ToImage(), Resize(size=resize, antialias=antialias)] if color_normalize is not None: @@ -29,18 +32,17 @@ def compose_transform(resize=None, class FeatureExtractionDataset(Dataset): - - def __init__(self, - wsi, - transform=None, - resize=None, - antialias=False, - color_normalize=None, - ): + def __init__( + self, + wsi, + transform=None, + resize=None, + antialias=False, + color_normalize=None, + ): self.wsi = wsi if not wsi.has_tiles: - raise ValueError("WSI 
does not have tiles. " - "Please create tiles first.") + raise ValueError("WSI does not have tiles. " "Please create tiles first.") self.tiles_coords = self.wsi.h5_file.get_coords() self.tile_ops = self.wsi.h5_file.get_tile_ops() if transform is not None: @@ -52,10 +54,12 @@ def __init__(self, resize_to = (int(self.tile_ops.height), int(self.tile_ops.width)) else: resize_to = None - self.transform = compose_transform(resize=resize_to, - antialias=antialias, - color_normalize=color_normalize, - feature_extraction=True) + self.transform = compose_transform( + resize=resize_to, + antialias=antialias, + color_normalize=color_normalize, + feature_extraction=True, + ) def __len__(self): return len(self.tiles_coords) @@ -63,8 +67,7 @@ def __len__(self): def __getitem__(self, idx): coords = self.tiles_coords[idx] x, y = coords - image = self.wsi.get_patch(y, x, - self.tile_ops.ops_width, - self.tile_ops.ops_height, - self.tile_ops.level) + image = self.wsi.get_patch( + y, x, self.tile_ops.ops_width, self.tile_ops.ops_height, self.tile_ops.level + ) return self.transform(image) diff --git a/lazyslide/loader/slides_balanced_loader.py b/lazyslide/loader/slides_balanced_loader.py index e0b4a1b..03817f1 100644 --- a/lazyslide/loader/slides_balanced_loader.py +++ b/lazyslide/loader/slides_balanced_loader.py @@ -11,50 +11,39 @@ class Slide: - - def __init__(self, - n_tiles, - start_index, - end_index, - shuffle=True, - seed=0 - ): + def __init__( + self, + n_tiles, + start_index, + ): self.n_tiles = n_tiles self.start_index = start_index - self.end_index = end_index - if shuffle: - rng = np.random.default_rng(seed) - self.pool = rng.choice( - np.arange(start_index, end_index), n_tiles, - # Each index can only be sample once - replace=False) - else: - self.pool = np.arange(start_index, start_index + n_tiles) - self.pool = deque(self.pool) def __len__(self): return self.n_tiles def get_tile(self): - if len(self.pool) == 0: + if self.n_tiles == 0: return None - return self.pool.pop() + self.n_tiles -= 1 + return self.start_index + self.n_tiles class SlidesDataset(Dataset): # TODO: Allow both tile labels or slide labels to be passed in - def __init__(self, - wsi_list, - resize=None, - antialias=False, - color_normalize=None, - transform=None, - max_taken=None, - shuffle_slides=True, - shuffle_tiles=True, - seed=0 - ): + def __init__( + self, + wsi_list, + resize=None, + antialias=False, + color_normalize=None, + transform=None, + max_taken=None, + shuffle_slides=True, + shuffle_tiles=True, + seed=0, + ): try: from ncls import NCLS except ImportError: @@ -74,7 +63,9 @@ def __init__(self, if wsi.tile_ops.downsample != 1: resize_to = (int(wsi.tile_ops.height), int(wsi.tile_ops.width)) self.resize_transform.append( - Compose([ToImage(), Resize(size=resize_to, antialias=antialias)]) + Compose( + [ToImage(), Resize(size=resize_to, antialias=antialias)] + ) ) else: self.resize_transform.append(None) @@ -86,9 +77,13 @@ def __init__(self, rng.shuffle(self.proxy_ix) self.seed = seed - self.shuffle_tiles = shuffle_tiles self.wsi_list = wsi_list - self.wsi_n_tiles = [len(wsi_list[i].tiles_coords) for i in self.proxy_ix] + self.wsi_n_tiles = [] + for i in self.proxy_ix: + wsi = wsi_list[i] + if shuffle_tiles: + wsi.shuffle_tiles(seed) + self.wsi_n_tiles.append(len(wsi.tiles_coords)) self.ix_slides = np.insert(np.cumsum(self.wsi_n_tiles), 0, 0) self.ixs = [] @@ -119,8 +114,9 @@ def __getitem__(self, ix): # change here how to get the coordinate top, left = wsi.tiles_coords[tile_ix] tile_ops = wsi.tile_ops - img = 
wsi.get_patch(left, top, tile_ops.ops_width, - tile_ops.ops_height, tile_ops.level) + img = wsi.get_patch( + left, top, tile_ops.ops_width, tile_ops.ops_height, tile_ops.level + ) if self.resize_transform is not None: resize_ops = self.resize_transform[slide_ix] if resize_ops is not None: @@ -133,8 +129,9 @@ def get_sampler_slides(self): less_than_max_taken = [] less_n_tiles = [] - for slide_ix, (n_tiles, start_index, end_index) in enumerate( - zip(self.wsi_n_tiles, self.starts, self.ends)): + for slide_ix, (n_tiles, start_index) in enumerate( + zip(self.wsi_n_tiles, self.starts) + ): if self.max_taken is not None: if n_tiles > self.max_taken: n_tiles = self.max_taken @@ -142,22 +139,24 @@ less_than_max_taken.append(self.wsi_list[slide_ix].image) less_n_tiles.append(n_tiles) - slides.append(Slide(n_tiles, start_index, end_index, - shuffle=self.shuffle_tiles)) + slides.append(Slide(n_tiles, start_index)) total_less = len(less_than_max_taken) if total_less > 0: if total_less > 30: less_than_max_taken = less_than_max_taken[0:30] less_n_tiles = less_n_tiles[0:30] - warn_stats = [f'{i}, {n} tiles' for i, n in zip(less_than_max_taken, less_n_tiles)] - warnings.warn(f"There are {total_less} slides has less than max_taken={self.max_taken}:" - f"{', '.join(warn_stats)}") + warn_stats = [ + f"{i}, {n} tiles" for i, n in zip(less_than_max_taken, less_n_tiles) + ] + warnings.warn( + f"There are {total_less} slides with fewer than max_taken={self.max_taken} tiles: " + f"{', '.join(warn_stats)}" + ) return slides class SlidesSampler(Sampler): - def __init__(self, slides, batch_size, drop_last=False): super().__init__() self.slides = slides @@ -168,14 +167,12 @@ def __len__(self): return sum([len(s) for s in self.slides]) // self.batch_size def __iter__(self): - _iter_slides = deepcopy(self.slides) exhaust_slides = [] batch = [] while True: - for slide in _iter_slides: t = slide.get_tile() # If tile can be acquired @@ -207,32 +204,35 @@ class SlidesBalancedLoader(DataLoader): the tiles are from different slides """ - def __init__(self, wsi_list, - batch_size=1, - resize=None, - antialias=False, - color_normalize=None, - transform=None, - max_taken=None, - drop_last=False, - shuffle_slides=True, - shuffle_tiles=True, - seed=0, - **kwargs, - ): - dataset = SlidesDataset(wsi_list, - resize=resize, - antialias=antialias, - color_normalize=color_normalize, - transform=transform, - max_taken=max_taken, - shuffle_slides=shuffle_slides, - shuffle_tiles=shuffle_tiles, - seed=seed, - ) - sampler = SlidesSampler(dataset.get_sampler_slides(), - batch_size=batch_size, - drop_last=drop_last) + def __init__( + self, + wsi_list, + batch_size=1, + resize=None, + antialias=False, + color_normalize=None, + transform=None, + max_taken=None, + drop_last=False, + shuffle_slides=True, + shuffle_tiles=True, + seed=0, + **kwargs, + ): + dataset = SlidesDataset( + wsi_list, + resize=resize, + antialias=antialias, + color_normalize=color_normalize, + transform=transform, + max_taken=max_taken, + shuffle_slides=shuffle_slides, + shuffle_tiles=shuffle_tiles, + seed=seed, + ) + sampler = SlidesSampler( + dataset.get_sampler_slides(), batch_size=batch_size, drop_last=drop_last + ) super().__init__( dataset=dataset, diff --git a/lazyslide/models/ctranspath.py b/lazyslide/models/ctranspath.py index 0ddc0cf..486969f 100644 --- a/lazyslide/models/ctranspath.py +++ b/lazyslide/models/ctranspath.py @@ -26,29 +26,32 @@ # ----------------------------------------------------------------------------- -def 
_init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): - """ ViT weight initialization + +def _init_vit_weights( + module: nn.Module, name: str = "", head_bias: float = 0.0, jax_impl: bool = False +): + """ViT weight initialization * When called without n, head_bias, jax_impl args it will behave exactly the same as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl """ if isinstance(module, nn.Linear): - if name.startswith('head'): + if name.startswith("head"): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) - elif name.startswith('pre_logits'): + elif name.startswith("pre_logits"): lecun_normal_(module.weight) nn.init.zeros_(module.bias) else: if jax_impl: nn.init.xavier_uniform_(module.weight) if module.bias is not None: - if 'mlp' in name: + if "mlp" in name: nn.init.normal_(module.bias, std=1e-6) else: nn.init.zeros_(module.bias) else: - trunc_normal_(module.weight, std=.02) + trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif jax_impl and isinstance(module, nn.Conv2d): @@ -63,20 +66,29 @@ def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., def _create_swin_transformer(variant, pretrained=False, **kwargs): cfg = { - 'url': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs + "url": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth", + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.9, + "interpolation": "bicubic", + "fixed_input_size": True, + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "first_conv": "patch_embed.proj", + "classifier": "head", + **kwargs, } return build_model_with_cfg( - SwinTransformer, variant, pretrained, + SwinTransformer, + variant, + pretrained, default_cfg=cfg, - img_size=cfg['input_size'][-2:], - num_classes=cfg['num_classes'], + img_size=cfg["input_size"][-2:], + num_classes=cfg["num_classes"], pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + **kwargs, + ) def _build_ctranspath_model(): @@ -86,15 +98,18 @@ def _build_ctranspath_model(): embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - embed_layer=ConvStem + embed_layer=ConvStem, + ) + return _create_swin_transformer( + "swin_tiny_patch4_window7_224", pretrained=False, **model_kwargs ) - return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=False, **model_kwargs) # ----------------------------------------------------------------------------- + class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. + r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: @@ -106,26 +121,34 @@ class WindowAttention(nn.Module): proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): - + def __init__( + self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0 + ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = head_dim**-0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords = torch.stack( + torch.meshgrid([coords_h, coords_w], indexing="ij") + ) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 @@ -137,7 +160,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., pro self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) + trunc_normal_(self.relative_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask: Optional[torch.Tensor] = None): @@ -147,20 +170,33 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, 
self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) @@ -175,7 +211,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. + r"""Swin Transformer Block. Args: dim (int): Number of input channels. @@ -192,9 +228,21 @@ class SwinTransformerBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -206,38 +254,58 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must be in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, - attn_drop=attn_drop, proj_drop=drop) + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, to_2tuple(self.window_size)) # nW, window_size, window_size, 1 + mask_windows = window_partition( + img_mask, to_2tuple(self.window_size) + ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None @@ -254,24 +322,36 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: - shifted_x = torch.roll(x, 
shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) else: shifted_x = x # partition windows - x_windows = window_partition(shifted_x, to_2tuple(self.window_size)) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition( + shifted_x, to_2tuple(self.window_size) + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, to_2tuple(self.window_size), H, W) # B H' W' C + shifted_x = window_reverse( + attn_windows, to_2tuple(self.window_size), H, W + ) # B H' W' C # reverse cyclic shift if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: x = shifted_x x = x.view(B, H * W, C) @@ -284,7 +364,7 @@ def forward(self, x): class PatchMerging(nn.Module): - r""" Patch Merging Layer. + r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. @@ -333,7 +413,7 @@ def flops(self): class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -351,10 +431,22 @@ class BasicLayer(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -362,17 +454,32 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) else: self.downsample = None @@ -391,7 +498,7 @@ def extra_repr(self) -> str: class SwinTransformer(nn.Module): - r""" Swin Transformer + r"""Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 @@ -415,12 +522,29 @@ class SwinTransformer(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, embed_layer=PatchEmbed, - use_checkpoint=False, weight_init='', **kwargs): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + embed_layer=PatchEmbed, + use_checkpoint=False, + weight_init="", + **kwargs, + ): super().__init__() self.num_classes = num_classes @@ -433,50 +557,69 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # split image into non-overlapping patches self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) num_patches = self.patch_embed.num_patches self.patch_grid = self.patch_embed.grid_size # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) else: self.absolute_pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers layers = [] for i_layer in range(self.num_layers): - layers += [BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) + layers += [ + BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=( + self.patch_grid[0] // (2**i_layer), + self.patch_grid[1] // (2**i_layer), + ), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if (i_layer < self.num_layers - 1) + else None, + use_checkpoint=use_checkpoint, + ) ] self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) - assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') 
- head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. - if weight_init.startswith('jax'): + assert weight_init in ("jax", "jax_nlhb", "nlhb", "") + head_bias = -math.log(self.num_classes) if "nlhb" in weight_init else 0.0 + if weight_init.startswith("jax"): for n, m in self.named_modules(): _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) else: @@ -484,18 +627,22 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + return {"absolute_pos_embed"} @torch.jit.ignore def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + return {"relative_position_bias_table"} def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) def forward_features(self, x): x = self.patch_embed(x) @@ -515,8 +662,15 @@ def forward(self, x): class ConvStem(torch.nn.Module): - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): super().__init__() assert patch_size == 4 @@ -532,8 +686,17 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_l stem = [] input_dim, output_dim = 3, embed_dim // 8 - for l in range(2): - stem.append(torch.nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) + for _ in range(2): + stem.append( + torch.nn.Conv2d( + input_dim, + output_dim, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ) + ) stem.append(torch.nn.BatchNorm2d(output_dim)) stem.append(torch.nn.ReLU(inplace=True)) input_dim = output_dim @@ -545,8 +708,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_l def forward(self, x): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC @@ -556,6 +720,7 @@ def forward(self, x): # ----------------------------------------------------------------------------- + class CTransPathFeatures: """ CTransPath pretrained feature extractor. 
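A note on the stochastic-depth schedule reformatted in the SwinTransformer hunk above: dpr spreads drop-path rates linearly over all blocks, and each BasicLayer receives its contiguous slice. A minimal sketch of what that slicing yields, assuming the swin-tiny defaults used by _build_ctranspath_model (depths=(2, 2, 6, 2)) and the drop_path_rate=0.1 default:

```python
import torch

# Mirrors the "stochastic depth decay rule" in SwinTransformer.__init__ above.
depths, drop_path_rate = (2, 2, 6, 2), 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i_layer in range(len(depths)):
    # The same slice expression each BasicLayer receives as drop_path.
    stage = dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]
    print(i_layer, [round(r, 3) for r in stage])
# Rates rise linearly from 0.0 in the first block to 0.1 in the last,
# so deeper blocks are dropped more aggressively during training.
```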
@@ -573,7 +738,7 @@ class CTransPathFeatures: GitHub: https://github.com/Xiyue-Wang/TransPath """ - tag = 'ctranspath' + tag = "ctranspath" license = """GNU General Public License v3.0""" citation = """ @{wang2022, @@ -593,11 +758,10 @@ def __init__(self, device=None, center_crop=False): self.model.head = torch.nn.Identity().to(self.device) checkpoint_path = hf_hub_download( - repo_id='jamesdolezal/CTransPath', - filename='ctranspath.pth' + repo_id="jamesdolezal/CTransPath", filename="ctranspath.pth" ) td = torch.load(checkpoint_path, map_location=self.device) - self.model.load_state_dict(td['model'], strict=True) + self.model.load_state_dict(td["model"], strict=True) self.model = self.model.to(self.device) self.model.eval() @@ -605,10 +769,8 @@ def __init__(self, device=None, center_crop=False): self.num_features = 768 all_transforms = [transforms.CenterCrop(224)] if center_crop else [] all_transforms += [ - transforms.Lambda(lambda x: x / 255.), - transforms.Normalize( - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225)) + transforms.Lambda(lambda x: x / 255.0), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ] self.transform = transforms.Compose(all_transforms) self.preprocess_kwargs = dict(standardize=False) diff --git a/lazyslide/models/hovernet.py b/lazyslide/models/hovernet.py index 2fe9463..d30f579 100644 --- a/lazyslide/models/hovernet.py +++ b/lazyslide/models/hovernet.py @@ -22,7 +22,7 @@ def segmentation_lines(mask_in): Useful for plotting results of tissue detection or other segmentation. """ assert ( - mask_in.dtype == np.uint8 + mask_in.dtype == np.uint8 ), f"Input mask dtype {mask_in.dtype} must be np.uint8" kernel = np.ones((3, 3), np.uint8) dilated = cv2.dilate(mask_in, kernel) @@ -40,15 +40,18 @@ def center_crop_im_batch(batch, dims, batch_order="BCHW"): dims: Amount to be cropped (tuple for H, W) """ assert ( - batch.ndim == 4 + batch.ndim == 4 ), f"ERROR input shape is {batch.shape} - expecting a batch with 4 dimensions total" assert ( - len(dims) == 2 + len(dims) == 2 ), f"ERROR input cropping dims is {dims} - expecting a tuple with 2 elements total" - assert batch_order in { - "BHCW", - "BCHW", - }, f"ERROR input batch order {batch_order} not recognized. Must be one of 'BHCW' or 'BCHW'" + assert ( + batch_order + in { + "BHCW", + "BCHW", + } + ), f"ERROR input batch order {batch_order} not recognized. Must be one of 'BHCW' or 'BCHW'" if dims == (0, 0): # no cropping necessary in this case @@ -87,7 +90,7 @@ def dice_loss(true, logits, eps=1e-3): dice_loss: the Sørensen–Dice loss. """ assert ( - true.dtype == torch.long + true.dtype == torch.long ), f"Input 'true' is of type {true.type}. It should be a long." num_classes = logits.shape[1] if num_classes == 1: @@ -536,7 +539,7 @@ def compute_hv_map(mask): np.ndarray: array of hv maps of shape (2, H, W). First channel corresponds to horizontal and second vertical. """ assert ( - mask.ndim == 2 + mask.ndim == 2 ), f"Input mask has shape {mask.shape}. Expecting a mask with 2 dimensions (H, W)" out = np.zeros((2, mask.shape[0], mask.shape[1])) @@ -615,7 +618,7 @@ def _get_gradient_hv(hv_batch, kernel_size=5): Tuple of (h_grad, v_grad) where each is a Tensor giving horizontal and vertical gradients respectively """ assert ( - hv_batch.shape[1] == 2 + hv_batch.shape[1] == 2 ), f"inputs have shape {hv_batch.shape}. 
Expecting tensor of shape (B, 2, H, W)" h_kernel, v_kernel = get_sobel_kernels(kernel_size, dt=hv_batch.dtype) @@ -733,12 +736,12 @@ def loss_hovernet(outputs, ground_truth, n_classes=None): nc_loss_ce = 0 loss = ( - np_loss_dice - + np_loss_ce - + hv_loss_mse - + hv_loss_grad - + nc_loss_dice - + nc_loss_ce + np_loss_dice + + np_loss_ce + + hv_loss_mse + + hv_loss_grad + + nc_loss_dice + + nc_loss_ce ) return loss @@ -760,7 +763,7 @@ def remove_small_objs(array_in, min_size): a different integer from 1 to n, where n is the number of total distinct contiguous regions """ assert ( - array_in.dtype == np.uint8 + array_in.dtype == np.uint8 ), f"Input dtype is {array_in.dtype}. Must be np.uint8" # remove elements below size threshold # each contiguous nucleus region gets a unique id @@ -775,7 +778,7 @@ def remove_small_objs(array_in, min_size): def _post_process_single_hovernet( - np_out, hv_out, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5 + np_out, hv_out, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5 ): """ Combine predictions of np channel and hv channel to create final predictions. @@ -856,7 +859,7 @@ def _post_process_single_hovernet( def post_process_batch_hovernet( - outputs, n_classes, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5 + outputs, n_classes, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5 ): """ Post-process HoVer-Net outputs to get a final predicted mask. @@ -950,6 +953,7 @@ def post_process_batch_hovernet( else: return out_detection + # plotting hovernet outputs diff --git a/lazyslide/models/retccl.py b/lazyslide/models/retccl.py index 2c30f16..3e3d176 100644 --- a/lazyslide/models/retccl.py +++ b/lazyslide/models/retccl.py @@ -9,10 +9,19 @@ # ----------------------------------------------------------------------------- + def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes, out_planes, stride=1): @@ -23,13 +32,22 @@ def conv1x1(in_planes, out_planes, stride=1): class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -63,12 +81,22 @@ def forward(self, x): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None, momentum_bn=0.1): + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + momentum_bn=0.1, + ): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = 
int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width, momentum=momentum_bn) @@ -104,7 +132,6 @@ def forward(self, x): class NormedLinear(nn.Module): - def __init__(self, in_features, out_features): super(NormedLinear, self).__init__() self.weight = Parameter(torch.Tensor(in_features, out_features)) @@ -116,11 +143,24 @@ def forward(self, x): class ResNet50(nn.Module): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None, two_branch=False, mlp=False, normlinear=False, - momentum_bn=0.1, attention=False, attention_layers=3, return_attn=False): + def __init__( + self, + block, + layers, + num_classes=1000, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + two_branch=False, + mlp=False, + normlinear=False, + momentum_bn=0.1, + attention=False, + attention_layers=3, + return_attn=False, + ): super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -134,8 +174,10 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) self.groups = groups self.base_width = width_per_group self.two_branch = two_branch @@ -143,21 +185,27 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, self.mlp = mlp linear = NormedLinear if normlinear else nn.Linear - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = norm_layer(self.inplanes, momentum=momentum_bn) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) if attention: - self.att_branch = self._make_layer(block, 512, attention_layers, 1, attention=True) + self.att_branch = self._make_layer( + block, 512, attention_layers, 1, attention=True + ) else: self.att_branch = None @@ -166,8 +214,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, if self.mlp: if self.two_branch: self.fc = nn.Sequential( - nn.Linear(512 * block.expansion, 512 * block.expansion), - nn.ReLU() + nn.Linear(512 * 
block.expansion, 512 * block.expansion), nn.ReLU() ) self.instDis = linear(512 * block.expansion, num_classes) self.groupDis = linear(512 * block.expansion, num_classes) @@ -175,7 +222,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, self.fc = nn.Sequential( nn.Linear(512 * block.expansion, 512 * block.expansion), nn.ReLU(), - linear(512 * block.expansion, num_classes) + linear(512 * block.expansion, num_classes), ) else: self.fc = nn.Linear(512 * block.expansion, num_classes) @@ -184,7 +231,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -199,7 +246,9 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention=False): + def _make_layer( + self, block, planes, blocks, stride=1, dilate=False, attention=False + ): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -213,23 +262,44 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention=F ) layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer, momentum_bn=self.momentum_bn)) + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + momentum_bn=self.momentum_bn, + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer, momentum_bn=self.momentum_bn)) + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + momentum_bn=self.momentum_bn, + ) + ) if attention: - layers.append(nn.Sequential( - conv1x1(self.inplanes, 128), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - conv1x1(128, 1), - nn.BatchNorm2d(1), - nn.Sigmoid() - )) + layers.append( + nn.Sequential( + conv1x1(self.inplanes, 128), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + conv1x1(128, 1), + nn.BatchNorm2d(1), + nn.Sigmoid(), + ) + ) return nn.Sequential(*layers) @@ -264,6 +334,7 @@ def forward(self, x): # ----------------------------------------------------------------------------- + class RetCCLFeatures: """ RetCCl pretrained feature extractor. 
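For orientation, a minimal usage sketch of the ResNet50 defined above, wired the way RetCCLFeatures does in the next hunk. The positional arguments (Bottleneck, [3, 4, 6, 3]) are the standard ResNet-50 layout and are an assumption here, since the hunk below elides them:

```python
import torch
from lazyslide.models.retccl import Bottleneck, ResNet50

# Assumed positional configuration; the diff below only shows keyword arguments.
model = ResNet50(
    Bottleneck,
    [3, 4, 6, 3],
    num_classes=128,
    mlp=False,
    two_branch=False,
    normlinear=True,
)
model.fc = torch.nn.Identity()  # mirrors RetCCLFeatures: expose pooled features
model.eval()
with torch.no_grad():
    feats = model(torch.zeros(1, 3, 256, 256))
print(feats.shape)  # expected torch.Size([1, 2048]), matching num_features = 2048
```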
@@ -271,7 +342,7 @@ class RetCCLFeatures: GitHub: https://github.com/Xiyue-Wang/RetCCL """ - tag = 'retccl' + tag = "retccl" license = "GNU General Public License v3.0" citation = """ @article{WANG2023102645, @@ -286,7 +357,6 @@ class RetCCLFeatures: """ def __init__(self, device=None, center_crop=False, ckpt=None): - if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device @@ -296,14 +366,11 @@ def __init__(self, device=None, center_crop=False, ckpt=None): num_classes=128, mlp=False, two_branch=False, - normlinear=True + normlinear=True, ) self.model.fc = torch.nn.Identity().to(self.device) if ckpt is None: - ckpt = hf_hub_download( - repo_id='jamesdolezal/RetCCL', - filename='retccl.pth' - ) + ckpt = hf_hub_download(repo_id="jamesdolezal/RetCCL", filename="retccl.pth") elif not isinstance(ckpt, str): raise ValueError(f"Invalid checkpoint path: {ckpt}") td = torch.load(ckpt, map_location=self.device) @@ -315,10 +382,8 @@ def __init__(self, device=None, center_crop=False, ckpt=None): self.num_features = 2048 all_transforms = [transforms.CenterCrop(256)] if center_crop else [] all_transforms += [ - transforms.Lambda(lambda x: x / 255.), - transforms.Normalize( - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225)), + transforms.Lambda(lambda x: x / 255.0), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ] self.transform = transforms.Compose(all_transforms) self.preprocess_kwargs = dict(standardize=False) @@ -326,4 +391,4 @@ def __init__(self, device=None, center_crop=False, ckpt=None): # --------------------------------------------------------------------- def __call__(self, *args, **kwargs): - return self.model(*args, **kwargs) \ No newline at end of file + return self.model(*args, **kwargs) diff --git a/lazyslide/normalizer.py b/lazyslide/normalizer.py index d0f6ff5..1f56be9 100644 --- a/lazyslide/normalizer.py +++ b/lazyslide/normalizer.py @@ -3,17 +3,15 @@ class ColorNormalizer(torch.nn.Module): - - T = Compose([ - ToImage(), - ToDtype(torch.float32, scale=True), - Lambda(lambda x: x*255) - ]) + T = Compose( + [ToImage(), ToDtype(torch.float32, scale=True), Lambda(lambda x: x * 255)] + ) def __init__(self, method="macenko"): super().__init__() import torchstain.torch.normalizers as norm + self.method = method if method == "macenko": normalizer = norm.TorchMacenkoNormalizer() diff --git a/lazyslide/readers/base.py b/lazyslide/readers/base.py index 2147d27..118a81c 100644 --- a/lazyslide/readers/base.py +++ b/lazyslide/readers/base.py @@ -53,9 +53,10 @@ def get_metadata(self): return self.metadata @staticmethod - def resize_img(img: np.ndarray, - scale: float, - ): + def resize_img( + img: np.ndarray, + scale: float, + ): dim = np.asarray(img.shape) dim = np.array(dim * scale, dtype=int) return cv2.resize(img, dim) @@ -63,22 +64,27 @@ def resize_img(img: np.ndarray, @staticmethod def _rgba_to_rgb(img): image_array_rgba = np.asarray(img) - image_array = cv2.cvtColor(image_array_rgba, cv2.COLOR_RGBA2RGB).astype(np.uint8) + image_array = cv2.cvtColor(image_array_rgba, cv2.COLOR_RGBA2RGB).astype( + np.uint8 + ) return image_array - def _get_ops_params(self, - level: int = None, - mpp: float = None, - magnification: int = None, ): - + def _get_ops_params( + self, + level: int = None, + mpp: float = None, + magnification: int = None, + ): has_level = level is not None has_mpp = mpp is not None has_mag = magnification is not None # only one argument is allowed if np.sum([has_level, has_mpp, has_mag]) != 1: - raise 
ValueError("Please specific one and only one argument," - "level, mpp or magnification") + raise ValueError( + "Please specific one and only one argument," + "level, mpp or magnification" + ) ops_resize = False ops_level = 0 @@ -111,8 +117,7 @@ def _get_ops_params(self, def _get_best_level_to_downsample(self, factor): if factor <= 1: - raise ValueError(f"Downsample factor must >= 1, " - f"the input is {factor}") + raise ValueError(f"Downsample factor must >= 1, " f"the input is {factor}") level_downsample = np.array(self.metadata.level_downsample) # Get levels that can be downsample diff --git a/lazyslide/readers/openslide.py b/lazyslide/readers/openslide.py index f197216..b852456 100644 --- a/lazyslide/readers/openslide.py +++ b/lazyslide/readers/openslide.py @@ -49,7 +49,6 @@ def get_patch( return region_rgb def get_level(self, level): - level = self.translate_level(level) width, height = self.slide.level_dimensions[level] diff --git a/lazyslide/readers/utils.py b/lazyslide/readers/utils.py index 959e170..570661d 100644 --- a/lazyslide/readers/utils.py +++ b/lazyslide/readers/utils.py @@ -3,9 +3,7 @@ @njit(cache=True) -def get_crop_left_top_width_height(img_width, img_height, - left, top, - width, height): +def get_crop_left_top_width_height(img_width, img_height, left, top, width, height): top_in = 0 <= top <= img_height left_in = 0 <= left <= img_width bottom_in = 0 <= (top + height) <= img_height @@ -14,9 +12,11 @@ def get_crop_left_top_width_height(img_width, img_height, left_out, right_out = not left_in, not right_in # If extract from region outside image if (top_out and bottom_out) or (left_out and right_out): - raise RuntimeError(f"Extracting region that are completely outside image. \n" - f"Image shape: H, W ({img_height}, {img_width}) \n" - f"Tile: Top, Left, Width, Height ({top}, {left}, {width}, {height})") + raise RuntimeError( + f"Extracting region that are completely outside image. 
\n" + f"Image shape: H, W ({img_height}, {img_width}) \n" + f"Tile: Top, Left, Width, Height ({top}, {left}, {width}, {height})" + ) if top_out and bottom_in and left_out and right_in: crop_left, crop_top = 0, 0 diff --git a/lazyslide/readers/vips.py b/lazyslide/readers/vips.py index 1d54fd1..4928695 100644 --- a/lazyslide/readers/vips.py +++ b/lazyslide/readers/vips.py @@ -55,12 +55,16 @@ def __init__( super().__init__(file, metadata) - def get_patch(self, - left, top, width, height, - level: int = None, - downsample: float = None, - fill=255, - ): + def get_patch( + self, + left, + top, + width, + height, + level: int = None, + downsample: float = None, + fill=255, + ): """Get a patch by x, y from top-left corner""" level = self.translate_level(level) img = self._get_vips_level(level) @@ -81,18 +85,21 @@ def _get_vips_level(self, level=0): """Lazy load and load only one for all image level""" handler = self.__level_vips_handler.get(level) if handler is None: - handler = vips.Image.new_from_file( - str(self.file), fail=True, level=level) + handler = vips.Image.new_from_file(str(self.file), fail=True, level=level) self.__level_vips_handler[level] = handler return handler @staticmethod def _get_vips_patch(image, left, top, width, height, fill=255): bg = [fill] - crop_left, crop_top, crop_w, crop_h, pos = \ - get_crop_left_top_width_height( - img_width=image.width, img_height=image.height, - left=left, top=top, width=width, height=height) + crop_left, crop_top, crop_w, crop_h, pos = get_crop_left_top_width_height( + img_width=image.width, + img_height=image.height, + left=left, + top=top, + width=width, + height=height, + ) cropped = image.crop(crop_left, crop_top, crop_w, crop_h) if pos is None: return cropped diff --git a/lazyslide/utils.py b/lazyslide/utils.py index aec0ba5..11ae121 100644 --- a/lazyslide/utils.py +++ b/lazyslide/utils.py @@ -25,12 +25,14 @@ def get_reader(reader="auto") -> Type[ReaderBase]: try: import openslide + readers["openslide"] = OpenSlideReader except (ModuleNotFoundError, OSError) as _: pass try: import pyvips as vips + readers["vips"] = VipsReader except (ModuleNotFoundError, OSError) as _: pass @@ -47,8 +49,9 @@ def get_reader(reader="auto") -> Type[ReaderBase]: if reader is not None: return reader elif reader not in reader_candidates: - raise ValueError(f"Reqeusted reader not available, " - f"must be one of {reader_candidates}") + raise ValueError( + f"Reqeusted reader not available, " f"must be one of {reader_candidates}" + ) else: return readers[reader] @@ -87,6 +90,7 @@ def check_wsi_path(path: str | Path, allow_download: bool = True) -> Path: return path raise ValueError("Path must be a URL or Path to existing file.") + @dataclass class TileOps: level: int = 0 diff --git a/lazyslide/wsi.py b/lazyslide/wsi.py index 455e5f0..2f140c3 100644 --- a/lazyslide/wsi.py +++ b/lazyslide/wsi.py @@ -18,8 +18,9 @@ @njit -def create_tiles_top_left(image_shape, tile_h, tile_w, - stride_h=None, stride_w=None, pad=True): +def create_tiles_top_left( + image_shape, tile_h, tile_w, stride_h=None, stride_w=None, pad=True +): """Create the tiles, return only coordination Padding works as follows: @@ -73,7 +74,7 @@ def create_tiles_top_left(image_shape, tile_h, tile_w, @njit -def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=.8): +def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=0.8): """Return a binary array that indicate which tile should be left Parameters @@ -89,15 +90,16 @@ def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=.8): """ 
@@ -73,7 +74,7 @@ def create_tiles_top_left(


 @njit
-def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=.8):
+def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=0.8):
 """Return a binary array that indicates which tiles should be kept

 Parameters
 ----------
@@ -89,15 +90,16 @@ def filter_tiles(mask, tiles_coords, tile_h, tile_w, filter_bg=.8):
 """
 use = []
 for x, y in tiles_coords:
- mask_region = mask[x:x + tile_h, y:y + tile_w]
+ mask_region = mask[x : x + tile_h, y : y + tile_w]
 bg_ratio = np.sum(mask_region == 0) / mask_region.size
 use.append(bg_ratio < filter_bg)
 return np.array(use, dtype=np.bool_)


 @njit
-def create_tiles_coords_index(image_shape, tile_h, tile_w,
- stride_h=None, stride_w=None, pad=True):
+def create_tiles_coords_index(
+ image_shape, tile_h, tile_w, stride_h=None, stride_w=None, pad=True
+):
 """Create the tiles, return the coordinates that comprise the tiles
 and the index of points for each rect

@@ -153,15 +155,14 @@
 ix3 = (ix_height + 1) * (n_tiles_width + 1) + ix_width
 indices.append([ix1, ix1 + 1, ix3 + 1, ix3])

- return (np.array(coordinates, dtype=np.uint),
- np.array(indices, dtype=np.uint))
+ return (np.array(coordinates, dtype=np.uint), np.array(indices, dtype=np.uint))


 def get_split_image_indices(image_height, image_width, min_side=20000):
 h, w = image_height, image_width
 size = h * w
 n = min_side
- if (size > n ** 2) or (h > n) or (w > n):
+ if (size > n**2) or (h > n) or (w > n):
 split_h = h > 1.5 * n
 split_w = w > 1.5 * n
@@ -173,8 +174,16 @@ def get_split_image_indices(image_height, image_width, min_side=20000):

 # If split, return the split chunks
 # Else, take the whole image
- ix_h = np.linspace(start=0, stop=h, num=n_chunk_h + 1, dtype=int) if split_h else [0, h]
- ix_w = np.linspace(start=0, stop=w, num=n_chunk_w + 1, dtype=int) if split_w else [0, w]
+ ix_h = (
+ np.linspace(start=0, stop=h, num=n_chunk_h + 1, dtype=int)
+ if split_h
+ else [0, h]
+ )
+ ix_w = (
+ np.linspace(start=0, stop=w, num=n_chunk_w + 1, dtype=int)
+ if split_w
+ else [0, w]
+ )

 slices = []
 for h1, h2 in pairwise(ix_h):
@@ -208,9 +217,11 @@ def __init__(
 self.contours, self.holes = self.h5_file.get_contours_holes()

 def __repr__(self):
- return (f"WSI(image={self.image}, "
- f"h5_file={self.h5_file.file},"
- f"reader={self._reader_class})")
+ return (
+ f"WSI(image={self.image}, "
+ f"h5_file={self.h5_file.file}, "
+ f"reader={self._reader_class})"
+ )

 @property
 def reader(self):
@@ -240,10 +251,9 @@ def create_mask(self, transform, name="user", level=-1, save=False):
 self.h5_file.set_mask(name, mask, level)
 self.h5_file.save()

- def create_tissue_mask(self, name="tissue", level=-1,
- chunk=True, chunk_at=20000,
- save=False,
- **kwargs):
+ def create_tissue_mask(
+ self, name="tissue", level=-1, chunk=True, chunk_at=20000, save=False, **kwargs
+ ):
 """Create a tissue mask using the preconfigured segmentation pipeline

@@ -266,13 +276,17 @@ def create_tissue_mask(self, name="tissue", level=-1,
 seg = TissueDetectionHE(**kwargs)

 # If the image is too large, run segmentation in chunks
- split_indices = get_split_image_indices(img_height, img_width, min_side=chunk_at)
+ split_indices = get_split_image_indices(
+ img_height, img_width, min_side=chunk_at
+ )
 if chunk & (split_indices is not None):
 mask = np.zeros((img_height, img_width), dtype=np.uint)
 for row in split_indices:
 for ixs in row:
 h1, h2, w1, w2 = ixs
- img_chunk = self.reader.get_patch(w1, h1, w2 - w1, h2 - h1, level=level)
+ img_chunk = self.reader.get_patch(
+ w1, h1, w2 - w1, h2 - h1, level=level
+ )
 chunk_mask = seg.apply(img_chunk)
 mask[h1:h2, w1:w2] = chunk_mask
 del img_chunk # Explicitly release memory
@@ -310,15 +324,18 @@ def create_tissue_contours(self, level=-1, save=False, **kwargs):

 def get_mask(self, name):
 return self.masks.get(name), self.masks_level.get(name)
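The chunked path in create_tissue_mask above stitches per-chunk masks back into one array. A sketch of that control flow, with segment and read_patch as hypothetical stand-ins for TissueDetectionHE.apply and reader.get_patch:

import numpy as np
from lazyslide.wsi import get_split_image_indices

def read_patch(left, top, width, height):  # hypothetical reader stand-in
    return np.zeros((height, width, 3), dtype=np.uint8)

def segment(img):  # hypothetical stand-in for TissueDetectionHE.apply
    return np.ones(img.shape[:2], dtype=np.uint8)

h, w = 1200, 500  # pretend this level exceeds chunk_at=500
full_mask = np.zeros((h, w), dtype=np.uint8)
for row in get_split_image_indices(h, w, min_side=500):
    for h1, h2, w1, w2 in row:
        # Each (h1, h2, w1, w2) window is segmented independently and written
        # back in place, so peak memory stays one chunk wide.
        full_mask[h1:h2, w1:w2] = segment(read_patch(w1, h1, w2 - w1, h2 - h1))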
- def create_tiles(self, tile_px,
- stride_px=None,
- pad=False,
- mpp=None,
- tolerance=.05,
- mask_name="tissue",
- background_fraction=.8,
- tile_pts=3,
- errors="ignore"):
+ def create_tiles(
+ self,
+ tile_px,
+ stride_px=None,
+ pad=False,
+ mpp=None,
+ tolerance=0.05,
+ mask_name="tissue",
+ background_fraction=0.8,
+ tile_pts=3,
+ errors="ignore",
+ ):
 """

 Parameters
 ----------
@@ -346,8 +363,10 @@ def create_tiles(self, tile_px,
 elif isinstance(tile_px, Iterable):
 tile_h, tile_w = (tile_px[0], tile_px[1])
 else:
- raise TypeError(f"input tile_px {tile_px} invalid. "
- f"Either (H, W), or a single integer for square tiles.")
+ raise TypeError(
+ f"Invalid tile_px {tile_px}. "
+ f"Must be (H, W), or a single integer for square tiles."
+ )

 if stride_px is None:
 stride_h, stride_w = tile_h, tile_w
@@ -356,17 +375,21 @@
 elif isinstance(stride_px, Iterable):
 stride_h, stride_w = (stride_px[0], stride_px[1])
 else:
- raise TypeError(f"input stride {stride_px} invalid. "
- f"Either (H, W), or a single integer.")
+ raise TypeError(
+ f"Invalid stride {stride_px}. "
+ f"Must be (H, W), or a single integer."
+ )

 use_mask = True
 mask, mask_level = self.get_mask(mask_name)
 if mask is None:
 # Try to use contours instead
 if self.contours is None:
- raise NameError(f"Mask with name '{mask_name}' does not exist, "
- f"use .create_tissue_contours() or .create_tissue_mask() "
- f"to annotate tissue location.")
+ raise NameError(
+ f"Mask with name '{mask_name}' does not exist, "
+ f"use .create_tissue_contours() or .create_tissue_mask() "
+ f"to annotate the tissue location."
+ )
 else:
 use_mask = False
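The accepted tile_px/stride_px forms normalized above, as call-site examples (wsi is assumed to be a WSI instance with a tissue mask already created):

wsi.create_tiles(tile_px=256)                 # square 256 x 256 tiles
wsi.create_tiles(tile_px=(256, 512))          # rectangular tiles, H=256, W=512
wsi.create_tiles(tile_px=256, stride_px=128)  # overlapping tiles, 128 px stride
wsi.create_tiles(tile_px=256, stride_px=(128, 256))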
@@ -383,14 +406,18 @@ def create_tiles(self, tile_px,
 downsample = 1

 if downsample < 1:
- raise ValueError(f"Cannot perform resize operation "
- f"with reqeust mpp={mpp} on image"
- f"mpp={self.metadata.mpp}, this will"
- f"require up-scaling of image.")
+ raise ValueError(
+ f"Cannot perform resize operation "
+ f"with requested mpp={mpp} on image "
+ f"mpp={self.metadata.mpp}, this would "
+ f"require up-scaling of the image."
+ )
 elif downsample == 1:
 ops_level = 0
 else:
- for ix, level_downsample in enumerate(self.metadata.level_downsample):
+ for ix, level_downsample in enumerate(
+ self.metadata.level_downsample
+ ):
 if lower_ds < level_downsample < upper_ds:
 downsample = level_downsample
 ops_level = ix
@@ -418,18 +445,20 @@ def create_tiles(self, tile_px,
 # Filter coords based on mask
 # TODO: Consider creating tiles based on the
 # bbox of different components
- self.tile_ops = TileOps(level=ops_level,
- mpp=mpp,
- downsample=downsample,
- height=tile_h, width=tile_w,
- ops_height=ops_tile_h,
- ops_width=ops_tile_w,
- mask_name=mask_name
- )
+ self.tile_ops = TileOps(
+ level=ops_level,
+ mpp=mpp,
+ downsample=downsample,
+ height=tile_h,
+ width=tile_w,
+ ops_height=ops_tile_h,
+ ops_width=ops_tile_w,
+ mask_name=mask_name,
+ )

 if use_mask:
 tiles_coords = create_tiles_top_left(
- image_shape, ops_tile_h, ops_tile_w,
- ops_stride_h, ops_stride_w, pad=pad)
+ image_shape, ops_tile_h, ops_tile_w, ops_stride_h, ops_stride_w, pad=pad
+ )
 # Map tile level to mask level
 # Only tiles can be scaled; the mask will not be resized
 mask_downsample = self.metadata.level_downsample[mask_level]
@@ -437,15 +466,18 @@
 ratio = tile_downsample / mask_downsample
 down_coords = (tiles_coords * ratio).astype(np.uint32)
 use_tiles = filter_tiles(
- mask, down_coords,
- int(ops_tile_h * ratio), int(ops_tile_w * ratio),
- filter_bg=background_fraction)
+ mask,
+ down_coords,
+ int(ops_tile_h * ratio),
+ int(ops_tile_w * ratio),
+ filter_bg=background_fraction,
+ )
 self.tiles_coords = tiles_coords[use_tiles].copy()
 else:
 rect_coords, rect_indices = create_tiles_coords_index(
- image_shape, ops_tile_h, ops_tile_w,
- ops_stride_h, ops_stride_w, pad=pad)
+ image_shape, ops_tile_h, ops_tile_w, ops_stride_h, ops_stride_w, pad=pad
+ )
 if len(self.contours) == 0:
 is_tiles = np.zeros(len(rect_coords), dtype=np.bool_)
 else:
@@ -454,13 +486,31 @@
 for c in self.contours:
 # Coerce the point to python int and let opencv decide the type
 # Flip x, y because it's different in opencv
- is_in.append(np.array([cv2.pointPolygonTest(c, (float(y), float(x)), measureDist=False) \
- for x, y in points]) == 1)
+ is_in.append(
+ np.array(
+ [
+ cv2.pointPolygonTest(
+ c, (float(y), float(x)), measureDist=False
+ )
+ for x, y in points
+ ]
+ )
+ == 1
+ )

 if len(self.holes) > 0:
 for c in self.holes:
- is_in.append(np.array([cv2.pointPolygonTest(c, (float(y), float(x)), measureDist=False) \
- for x, y in points]) == -1)
+ is_in.append(
+ np.array(
+ [
+ cv2.pointPolygonTest(
+ c, (float(y), float(x)), measureDist=False
+ )
+ for x, y in points
+ ]
+ )
+ == -1
+ )

 is_tiles = np.asarray(is_in).sum(axis=0) == 1
 # The number of points for each tile inside contours
@@ -491,30 +541,33 @@ def new_tiles(self, tiles_coords, height, width, level=0, format="top-left"):
 """
 self.tiles_coords = np.asarray(tiles_coords, dtype=np.uint)

- if format == 'left-top':
+ if format == "left-top":
 self.tiles_coords = self.tiles_coords[:, [1, 0]]

 height = int(height)
 width = int(width)
- self.tile_ops = TileOps(level=level,
- mpp=self.metadata.mpp,
- downsample=1,
- height=height,
- width=width,
- ops_height=height,
- ops_width=width,
- )
+ self.tile_ops = TileOps(
+ level=level,
+ mpp=self.metadata.mpp,
+ downsample=1,
+ height=height,
+ width=width,
+ ops_height=height,
+ ops_width=width,
+ )
 self.h5_file.set_coords(self.tiles_coords)
 self.h5_file.set_tile_ops(self.tile_ops)
 self.h5_file.save()
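cv2.pointPolygonTest drives the contour filtering above; with measureDist=False it returns +1.0 for a point inside the contour, -1.0 outside, and 0.0 on the edge. A minimal check:

import cv2
import numpy as np

square = np.array([[0, 0], [0, 10], [10, 10], [10, 0]], dtype=np.int32)
print(cv2.pointPolygonTest(square, (5.0, 5.0), measureDist=False))   # 1.0
print(cv2.pointPolygonTest(square, (15.0, 5.0), measureDist=False))  # -1.0
print(cv2.pointPolygonTest(square, (0.0, 5.0), measureDist=False))   # 0.0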
 def report(self):
 if self.tile_ops is not None:
- print(f"Generate tiles with mpp={self.tile_ops.mpp}, WSI mpp={self.metadata.mpp}\n"
- f"Total tiles: {len(self.tiles_coords)}"
- f"Use mask: '{self.tile_ops.mask_name}'\n"
- f"Generated Tiles in px (H, W): ({self.tile_ops.height}, {self.tile_ops.width})\n"
- f"WSI Tiles in px (H, W): ({self.tile_ops.ops_height}, {self.tile_ops.ops_width}) \n"
- f"Down sample ratio: {self.tile_ops.downsample}")
+ print(
+ f"Generated tiles with mpp={self.tile_ops.mpp}, WSI mpp={self.metadata.mpp}\n"
+ f"Total tiles: {len(self.tiles_coords)}\n"
+ f"Use mask: '{self.tile_ops.mask_name}'\n"
+ f"Generated Tiles in px (H, W): ({self.tile_ops.height}, {self.tile_ops.width})\n"
+ f"WSI Tiles in px (H, W): ({self.tile_ops.ops_height}, {self.tile_ops.ops_width})\n"
+ f"Downsample ratio: {self.tile_ops.downsample}"
+ )

 @staticmethod
 def _get_thumbnail(image_arr, size=1000):
@@ -529,19 +582,19 @@ def _get_thumbnail(image_arr, size=1000):

 return x_ratio, thumbnail

- def plot_tissue(self,
- size=1000,
- tiles=False,
- edgecolor=".5",
- linewidth=1,
- contours=False,
- contours_color="green",
- holes_color="black",
- ax=None,
- savefig=None,
- savefig_kws=None,
- ):
-
+ def plot_tissue(
+ self,
+ size=1000,
+ tiles=False,
+ edgecolor=".5",
+ linewidth=1,
+ contours=False,
+ contours_color="green",
+ holes_color="black",
+ ax=None,
+ savefig=None,
+ savefig_kws=None,
+ ):
 level = self.tile_ops.level if tiles else self.metadata.n_level - 1
 image_arr = self.reader.get_level(level)
 down_ratio, thumbnail = self._get_thumbnail(image_arr, size)
@@ -564,8 +617,9 @@ def plot_tissue(self,
 tile_w = self.tile_ops.width * down_ratio
 tiles = [Rectangle(t, tile_w, tile_h) for t in coords]

- collections = PatchCollection(tiles, facecolor="none",
- edgecolor=edgecolor, lw=linewidth)
+ collections = PatchCollection(
+ tiles, facecolor="none", edgecolor=edgecolor, lw=linewidth
+ )
 ax.add_collection(collections)

@@ -573,26 +627,30 @@ def plot_tissue(self,
 ratio = (1 / self.metadata.level_downsample[level]) * down_ratio
 if len(self.contours) > 0:
 for c in self.contours:
- ax.plot(c[:, 0] * ratio, c[:, 1] * ratio, lw=linewidth,
- c=contours_color)
+ ax.plot(
+ c[:, 0] * ratio, c[:, 1] * ratio, lw=linewidth, c=contours_color
+ )
 if len(self.holes) > 0:
 for h in self.holes:
- ax.plot(h[:, 0] * ratio, h[:, 1] * ratio, lw=linewidth, c=holes_color)
+ ax.plot(
+ h[:, 0] * ratio, h[:, 1] * ratio, lw=linewidth, c=holes_color
+ )

 if savefig:
 savefig_kws = {} if savefig_kws is None else savefig_kws
- save_kws = {'dpi': 150, **savefig_kws}
+ save_kws = {"dpi": 150, **savefig_kws}
 fig.savefig(savefig, **save_kws)

 return ax

- def plot_mask(self,
- name="tissue",
- size=1000,
- ax=None,
- savefig=None,
- savefig_kws=None,
- ):
+ def plot_mask(
+ self,
+ name="tissue",
+ size=1000,
+ ax=None,
+ savefig=None,
+ savefig_kws=None,
+ ):
 image_arr = self.masks.get(name)
 if image_arr is None:
 raise NameError(f"Cannot draw non-existent mask with name '{name}'")
@@ -606,7 +664,7 @@ def plot_mask(self,
 ax.set_axis_off()
 if savefig:
 savefig_kws = {} if savefig_kws is None else savefig_kws
- save_kws = {'dpi': 150, **savefig_kws}
+ save_kws = {"dpi": 150, **savefig_kws}
 fig.savefig(savefig, **save_kws)

 return ax
@@ -626,4 +684,3 @@ def get_tiles_coords(self):

 @property
 def has_tiles(self):
 return (self.tile_ops is not None) and (self.tiles_coords is not None)
-
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 5ebd79e..9735c5d 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -6,31 +6,54 @@

 IMAGE_WIDTH = 200
IMAGE_HEIGHT = 100 -test_func = partial(get_crop_left_top_width_height, - img_width=IMAGE_WIDTH, - img_height=IMAGE_HEIGHT) +test_func = partial( + get_crop_left_top_width_height, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT +) def test_get_crop_xy_wh(): - # inside assert test_func(left=50, top=50, width=10, height=10) == (50, 50, 10, 10, None) # upper-left - assert test_func(left=-10, top=-10, width=20, height=20) == (0, 0, 10, 10, "south-east") + assert test_func(left=-10, top=-10, width=20, height=20) == ( + 0, + 0, + 10, + 10, + "south-east", + ) # center-left assert test_func(left=-10, top=50, width=20, height=20) == (0, 50, 10, 20, "east") # lower-left - assert test_func(left=-10, top=90, width=20, height=20) == (0, 90, 10, 10, "north-east") + assert test_func(left=-10, top=90, width=20, height=20) == ( + 0, + 90, + 10, + 10, + "north-east", + ) # upper-center assert test_func(left=50, top=-10, width=20, height=20) == (50, 0, 20, 10, "south") # lower-center assert test_func(left=50, top=90, width=20, height=20) == (50, 90, 20, 10, "north") # upper-right - assert test_func(left=190, top=-10, width=20, height=20) == (190, 0, 10, 10, "south-west") + assert test_func(left=190, top=-10, width=20, height=20) == ( + 190, + 0, + 10, + 10, + "south-west", + ) # center-right assert test_func(left=190, top=50, width=20, height=20) == (190, 50, 10, 20, "west") # lower-right - assert test_func(left=190, top=90, width=20, height=20) == (190, 90, 10, 10, "north-west") + assert test_func(left=190, top=90, width=20, height=20) == ( + 190, + 90, + 10, + 10, + "north-west", + ) def test_get_crop_xy_wh_outside(): @@ -48,4 +71,3 @@ def test_get_crop_xy_wh_outside(): with pytest.raises(RuntimeError): test_func(left=50, top=-20, width=10, height=10) -
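Taken together, a typical session over the APIs touched in this diff (an illustrative sketch; the slide path, constructor arguments, and parameter values are assumptions, not part of the change):

from lazyslide.wsi import WSI

wsi = WSI("slide.svs")                               # hypothetical local H&E slide
wsi.create_tissue_mask(name="tissue", save=True)     # preconfigured H&E pipeline
wsi.create_tiles(tile_px=256, mpp=0.5, mask_name="tissue")
wsi.report()                                         # prints the TileOps summary
wsi.plot_tissue(tiles=True, savefig="overview.png")  # thumbnail with tile outlines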