Skip to content

Commit

Permalink
Remove the deprecated transform API in pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Nov 28, 2023
1 parent 16f47fd commit 4078749
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions lazyslide/torch_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.v2 import ToTensor, Normalize, Compose, Resize
import torch
from torch.utils.data import Dataset
from torchvision.transforms.v2 import ToDtype, Normalize, Compose, Resize


class WSIDataset(Dataset):

def __init__(self,
wsi: 'WSI',
wsi,
transform=None,
run_pretrained=False):
self.wsi = wsi
Expand All @@ -19,7 +20,7 @@ def __init__(self,
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

self.transform_ops = [ToTensor(), Normalize(mean=mean, std=std)]
self.transform_ops = [ToDtype(torch.float32, scale=True), Normalize(mean=mean, std=std)]
if transform is not None:
self.transform_ops = [transform]
if self.tile_ops.downsample != 1:
Expand All @@ -34,12 +35,8 @@ def __len__(self):
def __getitem__(self, idx):
coords = self.tiles_coords[idx]
x, y = coords
try:
image = self.wsi.get_patch(y, x,
self.tile_ops.ops_width,
self.tile_ops.ops_height,
self.tile_ops.level)
except Exception as e:
print(x, y, self.tile_ops)
raise e
image = self.wsi.get_patch(y, x,
self.tile_ops.ops_width,
self.tile_ops.ops_height,
self.tile_ops.level)
return self.transform(image)

0 comments on commit 4078749

Please sign in to comment.