Skip to content

Commit

Permalink
building out dataloader properly for training
Browse files Browse the repository at this point in the history
  • Loading branch information
johnathanchiu committed Sep 30, 2024
1 parent 3b79930 commit 5a2d07d
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 50 deletions.
27 changes: 0 additions & 27 deletions model/seg/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,6 @@ def __getitem__(self, idx):

return strips, labels # Return the strips and their corresponding labels

def create_strips_and_labels(self, image, img_name):
# Convert image to numpy array for processing
image_array = np.array(image)
height, width = image_array.shape[:2]
strip_height = 32 # Define the height of each strip
strips = []
labels = []

# Get bounding boxes for the current image
bboxes = self.bbox_data.get(
os.path.basename(img_name), []
) # Get bounding boxes for the image

# Create strips from the image
for y in range(0, height, strip_height):
strip = image_array[y : y + strip_height, :] # Get a strip
strips.append(strip)

# Check if any bounding box intersects with the current strip
label = self.check_intersection(bboxes, y, strip_height)
labels.append(label)

return (
strips,
labels,
) # Return the list of strips and their corresponding labels

def check_intersection(self, bboxes, y, strip_height):
# Check if any bounding box intersects with the strip
for bbox in bboxes:
Expand Down
26 changes: 26 additions & 0 deletions model/seg/imutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def create_strips_and_labels(self, image, img_name):
# Convert image to numpy array for processing
image_array = np.array(image)
height, width = image_array.shape[:2]
strip_height = 32 # Define the height of each strip
strips = []
labels = []

# Get bounding boxes for the current image
bboxes = self.bbox_data.get(
os.path.basename(img_name), []
) # Get bounding boxes for the image

# Create strips from the image
for y in range(0, height, strip_height):
strip = image_array[y : y + strip_height, :] # Get a strip
strips.append(strip)

# Check if any bounding box intersects with the current strip
label = self.check_intersection(bboxes, y, strip_height)
labels.append(label)

return (
strips,
labels,
) # Return the list of strips and their corresponding labels
170 changes: 147 additions & 23 deletions nb.ipynb

Large diffs are not rendered by default.

0 comments on commit 5a2d07d

Please sign in to comment.