diff --git a/lazyslide/torch_dataset.py b/lazyslide/torch_dataset.py
index 7c29741..127a02c 100644
--- a/lazyslide/torch_dataset.py
+++ b/lazyslide/torch_dataset.py
@@ -1,11 +1,12 @@
-from torch.utils.data import Dataset, DataLoader
-from torchvision.transforms.v2 import ToTensor, Normalize, Compose, Resize
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms.v2 import ToDtype, Normalize, Compose, Resize
 
 
 class WSIDataset(Dataset):
 
     def __init__(self,
-                 wsi: 'WSI',
+                 wsi,
                  transform=None,
                  run_pretrained=False):
         self.wsi = wsi
@@ -19,7 +20,7 @@ def __init__(self,
             mean = (0.5, 0.5, 0.5)
             std = (0.5, 0.5, 0.5)
 
-        self.transform_ops = [ToTensor(), Normalize(mean=mean, std=std)]
+        self.transform_ops = [ToDtype(torch.float32, scale=True), Normalize(mean=mean, std=std)]
         if transform is not None:
             self.transform_ops = [transform]
         if self.tile_ops.downsample != 1:
@@ -34,12 +35,8 @@ def __len__(self):
     def __getitem__(self, idx):
         coords = self.tiles_coords[idx]
         x, y = coords
-        try:
-            image = self.wsi.get_patch(y, x,
-                                       self.tile_ops.ops_width,
-                                       self.tile_ops.ops_height,
-                                       self.tile_ops.level)
-        except Exception as e:
-            print(x, y, self.tile_ops)
-            raise e
+        image = self.wsi.get_patch(y, x,
+                                   self.tile_ops.ops_width,
+                                   self.tile_ops.ops_height,
+                                   self.tile_ops.level)
         return self.transform(image)
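
For reference, a minimal standalone sketch of the new default pipeline introduced by this diff (ToDtype(torch.float32, scale=True) followed by Normalize). It is not part of the patch and assumes the tile is already a CHW uint8 torch tensor; if wsi.get_patch returns a numpy array or PIL image, an extra ToImage() step would be needed before ToDtype.

    import torch
    from torchvision.transforms.v2 import Compose, Normalize, ToDtype

    # Assumption: the patch is already a CHW uint8 torch tensor.
    transform = Compose([
        ToDtype(torch.float32, scale=True),                     # uint8 [0, 255] -> float32 [0, 1]
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),   # -> roughly [-1, 1]
    ])

    tile = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
    out = transform(tile)
    print(out.dtype, out.min().item(), out.max().item())        # torch.float32, ~-1.0, ~1.0

Compared with the removed ToTensor(), this keeps the conversion-and-scaling step explicit in the torchvision v2 transform API.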