@@ -160,9 +160,9 @@ def draw_bounding_boxes(
160160 the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
161161 `0 <= ymin < ymax < H`.
162162 labels (List[str]): List containing the labels of bounding boxes.
163- colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]] ): List containing the colors
164- or a single color for all of the bounding boxes. The colors can be represented as `str` or
165- `Tuple[int, int, int] `.
163+ colors (color or list of colors, optional ): List containing the colors
164+ of the boxes or single color for all boxes. The color can be represented as
165+ PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)` `.
166166 fill (bool): If `True` fills the bounding box with specified color.
167167 width (int): Width of bounding box.
168168 font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -231,7 +231,7 @@ def draw_segmentation_masks(
231231 image : torch .Tensor ,
232232 masks : torch .Tensor ,
233233 alpha : float = 0.8 ,
234- colors : Optional [List [Union [str , Tuple [int , int , int ]]]] = None ,
234+ colors : Optional [Union [ List [Union [str , Tuple [int , int , int ]]], str , Tuple [ int , int , int ]]] = None ,
235235) -> torch .Tensor :
236236
237237 """
@@ -243,10 +243,10 @@ def draw_segmentation_masks(
243243 masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
244244 alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
245245 0 means full transparency, 1 means no transparency.
246- colors (list or None ): List containing the colors of the masks. The colors can
247- be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
248- When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
249- with one element. By default, random colors are generated for each mask.
246+ colors (color or list of colors, optional ): List containing the colors
247+ of the masks or single color for all masks. The color can be represented as
248+ PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
249+ By default, random colors are generated for each mask.
250250
251251 Returns:
252252 img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
@@ -289,8 +289,7 @@ def draw_segmentation_masks(
289289 for color in colors :
290290 if isinstance (color , str ):
291291 color = ImageColor .getrgb (color )
292- color = torch .tensor (color , dtype = out_dtype )
293- colors_ .append (color )
292+ colors_ .append (torch .tensor (color , dtype = out_dtype ))
294293
295294 img_to_draw = image .detach ().clone ()
296295 # TODO: There might be a way to vectorize this
@@ -301,6 +300,6 @@ def draw_segmentation_masks(
301300 return out .to (out_dtype )
302301
303302
304- def _generate_color_palette (num_masks ):
303+ def _generate_color_palette (num_masks : int ):
305304 palette = torch .tensor ([2 ** 25 - 1 , 2 ** 15 - 1 , 2 ** 21 - 1 ])
306305 return [tuple ((i * palette ) % 255 ) for i in range (num_masks )]
0 commit comments