import os
import xml.etree.ElementTree as ET

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from segment_anything.utils.transforms import ResizeLongestSide

# Pad an image tensor to a square - based on SAM's preprocessing
def pad_image(x: torch.Tensor, square_length=1024) -> torch.Tensor:
    # x has shape (..., H, W)
    h, w = x.shape[-2:]
    padh = square_length - h
    padw = square_length - w
    x = F.pad(x, (0, padw, 0, padh))  # pad on the right and bottom only
    return x

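# A minimal sketch of what pad_image does (shapes are illustrative):
#   x = torch.zeros(3, 600, 800)
#   pad_image(x).shape  # -> torch.Size([3, 1024, 1024]), zeros on the bottom/right
# F.pad takes the padding as (left, right, top, bottom) over the last two dims.
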
# Custom dataset
class INC_SAMVOC2012Dataset(object):
    def __init__(self, voc_root, type):
        self.voc_root = voc_root
        self.num_of_data = -1
        self.dataset = {}  # Each item is ["filename", "class_name", [xmin, ymin, xmax, ymax]]
        self.resizelongestside = ResizeLongestSide(target_length=1024)
        pixel_mean = [123.675, 116.28, 103.53]
        pixel_std = [58.395, 57.12, 57.375]
        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)

        # Read through all the annotation files and build a dictionary.
        # Key of the dictionary is a running index (idx).
        # Value of the dictionary is the filename, class name and bounding box of one object.
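        # e.g. self.dataset[0] = ['2007_000032', 'aeroplane', [104, 78, 375, 183]]
        # (illustrative values - one entry is created per <object> in each XML file)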
        annotation_dir = os.path.join(voc_root, "Annotations")
        files = os.listdir(annotation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(annotation_dir, f))]  # filter out directories
        annotation_files = [os.path.join(annotation_dir, x) for x in files]

        # Get the name list of the segmentation files
        segmentation_dir = os.path.join(voc_root, "SegmentationObject")
        files = os.listdir(segmentation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(segmentation_dir, f))]  # filter out directories
        segmentation_files = files

        # Based on the type (train/val), select the data split
        train_val_dir = os.path.join(voc_root, 'ImageSets/Segmentation/')
        if type == 'train':
            txt_file_name = 'train.txt'
        elif type == 'val':
            txt_file_name = 'val.txt'
        else:
            raise ValueError("Type of dataset should be 'train' or 'val'")

        with open(os.path.join(train_val_dir, txt_file_name), 'r') as f:
            permitted_files = [row.rstrip('\n') for row in f]

        for file in annotation_files:
            file_name = os.path.basename(file).split('.xml')[0]

            if file_name not in permitted_files:
                continue  # skip files that are not in the selected split

            # Check whether there is a matching segmentation file for this annotation
            if file_name + '.png' in segmentation_files:
                tree = ET.parse(file)
                root = tree.getroot()
                for child in root:
                    if child.tag == 'object':
                        details = [file_name]
                        for node in child:
                            if node.tag == 'name':
                                object_name = node.text
                            if node.tag == 'bndbox':
                                for coordinates in node:
                                    if coordinates.tag == 'xmax':
                                        xmax = int(coordinates.text)
                                    if coordinates.tag == 'xmin':
                                        xmin = int(coordinates.text)
                                    if coordinates.tag == 'ymax':
                                        ymax = int(coordinates.text)
                                    if coordinates.tag == 'ymin':
                                        ymin = int(coordinates.text)
                                boundary = [xmin, ymin, xmax, ymax]
                        details.append(object_name)
                        details.append(boundary)
                        self.num_of_data += 1
                        self.dataset[self.num_of_data] = details

    def __len__(self):
        # num_of_data is a 0-based running index, so the sample count is len(self.dataset)
        return len(self.dataset)

    # Preprocess the segmentation mask. Output the binary mask of a single object.
    def preprocess_segmentation(self, filename, bounding_box, pad=True):

        # Read the object segmentation mask (pil_to_tensor returns a (1, H, W) torch tensor)
        segment_mask = Image.open(self.voc_root + 'SegmentationObject/' + filename + '.png')
        segment_mask_np = torchvision.transforms.functional.pil_to_tensor(segment_mask)

        # Crop the segmentation based on the bounding box
        xmin, ymin = int(bounding_box[0]), int(bounding_box[1])
        xmax, ymax = int(bounding_box[2]), int(bounding_box[3])
        cropped_mask = segment_mask.crop((xmin, ymin, xmax, ymax))
        cropped_mask_np = torchvision.transforms.functional.pil_to_tensor(cropped_mask)

        # Count the majority element inside the box
        bincount = np.bincount(cropped_mask_np.reshape(-1))
        bincount[0] = 0  # ignore the black (background) pixels
        if bincount.shape[0] >= 256:
            bincount[255] = 0  # ignore the white (object-boundary) pixels
        majority_element = bincount.argmax()
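        # Illustrative example: if the cropped box contains pixel values
        # [0, 0, 5, 5, 5, 9, 255], zeroing the counts of 0 and 255 leaves
        # 5 as the majority object id for this bounding box.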

        # Based on the majority element, binarize the segmentation mask
        segment_mask_np[(segment_mask_np != 0) & (segment_mask_np != majority_element)] = 0
        segment_mask_np[segment_mask_np == majority_element] = 1

        # Pad the segmentation mask to 1024x1024 (for batching in a dataloader)
        if pad:
            segment_mask_np = pad_image(segment_mask_np)

        return segment_mask_np

    # Preprocess the image into an appropriate format for SAM
    def preprocess_image(self, img):
        # Mirrors SamPredictor.set_image() in predictor.py
        img = np.array(img)
        input_image = self.resizelongestside.apply_image(img)
        input_image_torch = torch.as_tensor(input_image, device='cpu')
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std  # normalize
        original_size = img.shape[:2]
        input_size = tuple(input_image_torch.shape[-2:])

        return pad_image(input_image_torch), original_size, input_size
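    # Illustrative shapes: for a 500x375 (WxH) JPEG, the array is (375, 500, 3);
    # ResizeLongestSide scales the longest side to 1024, so original_size == (375, 500),
    # input_size == (768, 1024), and the returned image is padded to (3, 1024, 1024).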

    def __getitem__(self, idx):
        data = self.dataset[idx]
        filename, classname = data[0], data[1]
        bounding_box = data[2]

        # Ground-truth mask: preprocessing without padding
        mask_gt = self.preprocess_segmentation(filename, bounding_box, pad=False)

        # Read and preprocess the image
        image, original_size, input_size = self.preprocess_image(Image.open(self.voc_root + 'JPEGImages/' + filename + '.jpg'))
        prompt = bounding_box  # bounding box prompt - input_boxes x1, y1, x2, y2
        training_data = {}
        training_data['image'] = image
        training_data['original_size'] = original_size
        training_data['input_size'] = input_size
        training_data['ground_truth_mask'] = mask_gt
        training_data['prompt'] = prompt
        return (training_data, mask_gt)  # data, label


class INC_SAMVOC2012Dataloader:
    def __init__(self, batch_size, **kwargs):
        self.batch_size = batch_size
        self.dataset = []
        ds = INC_SAMVOC2012Dataset(kwargs['voc_root'], kwargs['type'])
        # Add (input_data, label) pairs into self.dataset
        for i in range(len(ds)):
            self.dataset.append(ds[i])

    def __iter__(self):
        # Samples are yielded one at a time; batch_size is kept for the
        # dataloader interface, but no batching is performed here.
        for input_data, label in self.dataset:
            yield input_data, label
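
# Minimal usage sketch (the VOC root below is an illustrative local path,
# not part of this module; it should point at an extracted VOC2012 devkit):
if __name__ == '__main__':
    dataloader = INC_SAMVOC2012Dataloader(batch_size=1,
                                          voc_root='./VOCdevkit/VOC2012/',
                                          type='val')
    for input_data, label in dataloader:
        print(input_data['image'].shape, label.shape)
        break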