# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

import os
-import numpy as np
import torch
-from PIL import Image

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.io import read_image
+from torchvision.ops.boxes import masks_to_boxes
+from torchvision import datapoints as dp
+from torchvision.transforms.v2 import functional as F
+from torchvision.transforms import v2 as T
+

from engine import train_one_epoch, evaluate
import utils
-import transforms as T


-class PennFudanDataset(object):
+class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
@@ -28,47 +31,36 @@ def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = Image.open(img_path).convert("RGB")
-        # note that we haven't converted the mask to RGB,
-        # because each color corresponds to a different instance
-        # with 0 being background
-        mask = Image.open(mask_path)
-
-        mask = np.array(mask)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
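+        # read_image returns uint8 tensors of shape (C, H, W); the mask stays
+        # single-channel, with 0 encoding the background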
        # instances are encoded as different colors
-        obj_ids = np.unique(mask)
+        obj_ids = torch.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
+        num_objs = len(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
-        masks = mask == obj_ids[:, None, None]
+        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # get bounding box coordinates for each mask
-        num_objs = len(obj_ids)
-        boxes = []
-        for i in range(num_objs):
-            pos = np.where(masks[i])
-            xmin = np.min(pos[1])
-            xmax = np.max(pos[1])
-            ymin = np.min(pos[0])
-            ymax = np.max(pos[0])
-            boxes.append([xmin, ymin, xmax, ymax])
-
-        boxes = torch.as_tensor(boxes, dtype=torch.float32)
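+        # masks_to_boxes returns one (xmin, ymin, xmax, ymax) box per binary mask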
+        boxes = masks_to_boxes(masks)
+
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
-        masks = torch.as_tensor(masks, dtype=torch.uint8)

-        image_id = torch.tensor([idx])
+        image_id = idx
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

+        # Wrap sample and targets into torchvision datapoints:
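+        # (wrapping lets the v2 transforms update boxes and masks together with the image)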
+        img = dp.Image(img)
+
        target = {}
-        target["boxes"] = boxes
+        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
+        target["masks"] = dp.Mask(masks)
        target["labels"] = labels
-        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
@@ -81,9 +73,10 @@ def __getitem__(self, idx):
    def __len__(self):
        return len(self.imgs)

+

def get_model_instance_segmentation(num_classes):
-    # load an instance segmentation model pre-trained pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+    # load an instance segmentation model pre-trained on COCO
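+    # weights="DEFAULT" selects the best available pre-trained weights and
+    # replaces the deprecated pretrained=True flag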
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):

def get_transform(train):
    transforms = []
-    transforms.append(T.ToTensor())
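+    # v2 replaces ToTensor with ToImage followed by ToDtype(..., scale=True)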
+    transforms.append(T.ToImage())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
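+    # convert images to float in [0, 1] and strip the datapoint wrappers before returning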
+    transforms.append(T.ToDtype(torch.float, scale=True))
+    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)


@@ -160,6 +155,6 @@ def main():
    evaluate(model, data_loader_test, device=device)

    print("That's it!")
-
+
if __name__ == "__main__":
    main()