
Utility to draw Semantic Segmentation Masks #3272


Closed
oke-aditya opened this issue Jan 21, 2021 · 16 comments · Fixed by #3330

Comments

@oke-aditya
Contributor

🚀 Feature

We recently added a utility to draw bounding boxes, which works really well with detection models (#2785, #2556).
It would be nice to have a similar utility to draw the segmentation masks that we obtain from instance segmentation models.

Motivation

Same as for the bounding box util. These utilities are very useful to have, and they reduce users' dependence on other plotting libraries.

Pitch

Our API should be compatible with the segmentation models, so we should probably use Tensors.
I think most params can remain the same as in the previous util, which also keeps the API consistent.

@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
)

We might need to find a method we can use to draw this with PIL.
We used draw.rectangle() to draw boxes in the utils; maybe PIL has some functionality that can help us draw such shapes.
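For instance, a minimal sketch with PIL's ImageDraw.bitmap(), which stamps a single color through a 1-bit mask (the ellipse mask here is just a made-up example):

from PIL import Image, ImageDraw

img = Image.new("RGB", (224, 224))
draw = ImageDraw.Draw(img)

# A made-up boolean mask for one class, as a 1-bit ("1" mode) PIL image
mask = Image.new("1", (224, 224))
ImageDraw.Draw(mask).ellipse((50, 50, 150, 150), fill=1)

# bitmap() fills the image with one color wherever the mask is non-zero
draw.bitmap((0, 0), mask, fill="red")
img.save("bitmap_sketch.png")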

Alternatives

Let's discuss further how this API would work with our models.
Also, it would be nice if this worked directly with the instance segmentation model Mask R-CNN.

@datumbox
Contributor

@oke-aditya +1 on the proposal. I like the idea of providing a similar API as for draw_bounding_boxes().

@oke-aditya
Contributor Author

I would also suggest adding an example notebook/script/tutorial where we can show how both utils work.
I will give this a try using PIL and post a sample result here; let's see if it is satisfactory.

@datumbox
Contributor

I will give this a try using PIL and post a sample result here; let's see if it is satisfactory.

Perfect! I think using PIL should be fine given that's how the bounding box version works. Looking forward to it.

@oke-aditya oke-aditya changed the title from "Utility to draw Segmentation Masks" to "Utility to draw Semantic Segmentation Masks" on Jan 23, 2021
@oke-aditya
Contributor Author

oke-aditya commented Jan 23, 2021

Hey @datumbox. For instance segmentation models we can do something like #3280.

I'm looking into semantic segmentation models.
The output of the model seems to be, for each pixel of the image, a probability for each of the 21 classes (the pretrained classes, if it is COCO).

I guess the predicted label is then the argmax() over the n classes for each pixel.
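For example, a small sketch of the shapes involved (assuming the torchvision FCN output format):

import torch
from torchvision.models.segmentation import fcn_resnet50

model = fcn_resnet50(pretrained=False, num_classes=21).eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))['out']  # (1, 21, H, W) per-pixel class scores
preds = out.argmax(dim=1)                           # (1, H, W) predicted class id per pixel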

I think colors can remain the same.
Maybe the user simply passes a list of Tuple[int, int, int], each mapping to the i-th label?
Or the user passes a List[str], each mapping to the i-th label?

I think we can use ImageDraw.point to draw each point?
This would be very slow, I think, for bigger images, since we fill a color on every pixel of the tensor.

The width param makes little sense here, since we will be drawing each pixel of the image.

@datumbox
Contributor

datumbox commented Jan 25, 2021

I guess the predicted label is then the argmax() over the n classes for each pixel.

Agreed.

I think colors can remain the same.

Yes, they should remain the same. We currently support both tuples of RGB ints and strings. See an example from our tests:

colors = ["green", "#FF00FF", (0, 255, 0), "red"]

Or the user passes a List[str], each mapping to the i-th label?

We should use colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, like on boxes.

This would be very slow, I think, for bigger images, since we fill a color on every pixel of the tensor.

It's worth investigating a couple of options and measuring them. If you send a PR and tag me, I'm happy to share feedback.

The width param makes little sense here, since we will be drawing each pixel of the image.

Yes, I think you are right.

@oke-aditya
Contributor Author

oke-aditya commented Jan 25, 2021

Great! I will send a PR and we can investigate a few options 😀

This is what I have in my mind.

[Image: sample semantic segmentation output, taken from the FCN paper]

@datumbox
Contributor

@oke-aditya Awesome, looking forward to it.

I'm actually doing work around segmentation, so I'm definitely going to use your prototypes to validate the models. :)

@oke-aditya
Contributor Author

oke-aditya commented Jan 30, 2021

I had a go at this 😃 Here is code to reproduce a simple mask using draw.point. It is not that slow.

import torch
from PIL import Image, ImageDraw
from torchvision.models.segmentation import fcn_resnet50

image = torch.ones(1, 3, 224, 224)
fcn = fcn_resnet50(pretrained=False, num_classes=21)
fcn = fcn.eval()

with torch.no_grad():
    fcn_out = fcn(image)['out'][0]

output_predictions = fcn_out.argmax(0)

# Nested Python list of predicted class ids, for cheap indexing in the loop below
img_preds = output_predictions.to(torch.int64).tolist()

image = torch.ones((3, 224, 224), dtype=torch.uint8)
ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)

# One color per output class
labels = ["blue", "red", "purple", "orange", "red", "purple", "orange", "red",
          "purple", "black", "red", "yellow", "black", "red", "yellow", "orange", "red",
          "yellow", "aqua", "red", "yellow", "aqua"]

draw = ImageDraw.Draw(img_to_draw)

# Fill every pixel individually with the color of its predicted class
for i in range(len(img_preds)):
    for j in range(len(img_preds[i])):
        draw.point((i, j), fill=labels[img_preds[i][j]])

img_to_draw.save("temp.png")

[Image: temp.png — output of the draw.point prototype above]

@oke-aditya
Contributor Author

oke-aditya commented Jan 30, 2021

Some points I'm unclear about:

  1. Is the user interested in plotting any text here? If yes, then where and how?
     If not, then we can forgo the labels, font, and font_size parameters.

  2. This is not very slow; it takes under 1 second for a decent-sized image.

  3. How do we assign default colors? Otherwise everything will be colored with a single color!

@datumbox
Contributor

datumbox commented Feb 1, 2021

@oke-aditya Thanks for providing the prototype implementation. This is really helpful because it raises lots of interesting questions about the proposed API.

image = torch.ones((3, 224, 224), dtype=torch.uint8)

From the above, I understand that you create a new image after making the prediction and before starting to draw. Do you intentionally propose to create a new picture that contains only the masks, OR do you plan to apply the masks on top of the original image (with transparency), as we do in the bounding box util?

Concerning your questions, here are some thoughts:

  1. Good point about the labels. I can't think of a way of including them without completely cluttering the picture. In most visualizations of semantic segmentation the labels are omitted, so I think we could do the same here. The only case I've seen of the labels being used is to build a legend, which I'm not sure is useful here. If you have other ideas, let me know!
  2. That's probably because we write one pixel at a time. Vectorizing the operation can probably speed up the execution; see the example and the sketch after this list.
  3. I think in this util the colours might need to be a mandatory field. Most likely I would require the user to provide a list of colours equal to the number of classes. Alternative ways include defining a palette, choosing random colours or repeating the input colours, but all of these were discussed and turned down on the earlier PR versions of the bounding box util because they require a large number of assumptions.
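For instance, a minimal vectorized sketch of point 2 (the colour table here is a made-up placeholder):

import torch
from PIL import Image

preds = torch.randint(0, 21, (224, 224))                          # assumed (H, W) class ids
colour_table = torch.randint(0, 256, (21, 3), dtype=torch.uint8)  # one RGB row per class

rgb = colour_table[preds]                 # (H, W, 3) via advanced indexing, no Python loop
Image.fromarray(rgb.numpy()).save("vectorized.png")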

@fmassa if you have any thoughts on the above, pitch in.

@oke-aditya
Contributor Author

oke-aditya commented Feb 1, 2021

I think make_grid avoids in-place modification, and we should avoid in-place operations for all plotting utils.

Maybe people want to have a look at the predictions, do further post-processing on the image, and then redraw the predictions. I guess we should just do new_image = image.clone() and avoid in-place plotting in all utils.

I am not sure a mask over the original image with transparency would be nice here. Maybe that is better suited to instance segmentation than to semantic?

  1. I too think we can drop the labels. Even the paper outputs attached above do not use labels, and we cannot plot labels consistently without overlap.

  2. Vectorizing is fine, though this isn't very slow. I'm not sure how vectorizing would work with PIL.

  3. Colors, I guess, should be mandatory, since semantic segmentation depends heavily on them. One way of building a color palette is shown below:

import torch
from PIL import Image

# Create a color palette, selecting a color for each of the 21 classes
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# Plot the semantic segmentation predictions of 21 classes in each color
# (output_predictions is the (H, W) argmax tensor; input_image is the original PIL image)
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

This is quite fast and good, but we lose control over the colors, which could otherwise easily be supplied as a list or tuple.

Taken from the segmentation Colab notebook shared on PyTorch Hub.

@datumbox
Contributor

datumbox commented Feb 1, 2021

A couple of clarifications and additional notes on my side:

we should avoid in-place operations for all plotting utils.

I agree we should not modify the image in place. My question was whether you planned to put transparency on top of the original image, or whether you wanted to create a new non-transparent image using only the classifications. If you intend to do the latter, then providing the image tensor is unnecessary because we only use the output predictions.

This is quite fast and good, but we lose control over the colors.

The example you provide follows the second approach (no transparency) and uses Palette image types. Both PIL and TorchVision support palettes and, as you indicate, they allow you to build segmentation masks even for massive images. Perhaps to avoid losing control over colours, we could allow the user to define a complete list of colours for all classes; if they don't provide such a complete list, we can fall back to a colour-palette approach similar to your snippet.
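A rough sketch of such a fallback (a hypothetical helper, in the spirit of your snippet):

import torch

def _generate_color_palette(num_classes):
    # Derive a distinct RGB tuple per class id when the user passes colors=None
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    return [tuple(((i * palette) % 255).tolist()) for i in range(num_classes)]

colours = _generate_color_palette(21)   # used only when no colours are provided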

I'm not sure how vectorizing would work with PIL.

As per your example snippet, using palettes avoids calling PIL's draw.point method. Instead, all the processing is done in parallel using PyTorch tensors, which is very fast.

Thoughts?

@oke-aditya
Contributor Author

I have thought briefly about all the points.

  1. Yes, providing the image tensor is unnecessary. But let's go for consistency across the utils; it doesn't hurt much.
     It would be surprising if the user had to remember that this is the only util to which we don't pass an image.

     We can just do new_image = image.clone() and continue plotting on new_image.
     We should probably do this for bounding boxes too and avoid the in-place operation there.

  2. The above example is vectorized and has good defaults!
     We could also let the user optionally provide a list of colors, one for each class.

So the new API proposal 😄

@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
)

Args:
    image (Tensor): Tensor of shape (C, H, W).
    masks (Tensor): Tensor of shape (H, W), where each value is the predicted class.
    colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of the masks.
        The colors can be represented as str or Tuple[int, int, int].
        Its length should be equal to the number of output classes.

@datumbox
Contributor

datumbox commented Feb 1, 2021

Looks like a very reasonable proposal to me. Thanks for looking into it. :)

One last question: what do you think will happen if you take the image produced from the masks using the process you describe above and do a weighted linear combination with the original image? I think that if the background of the palette is set to be transparent instead of black, this operation might project the colour of the masks onto the original image. Thoughts?
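Something along these lines (just a sketch; the blend factor and input tensors are made up):

import torch

alpha = 0.6  # assumed blend factor
original = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
mask_img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # colourized masks

# Weighted linear combination of the colourized masks and the original image
blended = (alpha * mask_img.float() + (1 - alpha) * original.float()).to(torch.uint8)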

@oke-aditya
Contributor Author

I think so too; if we give the background a transparency channel, we might be able to project it onto the image.
We could additionally have a fill parameter (like we did for boxes). Let me have a run with this and post the outputs 😃

@oke-aditya
Contributor Author

oke-aditya commented Feb 2, 2021

Hi @datumbox, I opened PR #3330 since we are now quite OK with the API.

Here is the output. I think I'm missing something really small: a transparent background instead of a black one.
Here is the code.

import torch
from PIL import Image
from torchvision.models.segmentation import fcn_resnet50
from torchvision.utils import draw_segmentation_masks  # from PR #3330

image = torch.ones(1, 3, 224, 224)
fcn = fcn_resnet50(pretrained=False, num_classes=21)
fcn = fcn.eval()

with torch.no_grad():
    fcn_out = fcn(image)['out'][0]

masks = fcn_out.argmax(0)
image = torch.ones((3, 224, 224), dtype=torch.uint8)
# One color per output class
colors = ["blue", "red", "purple", "orange", "red", "purple", "orange", "red",
          "purple", "black", "red", "yellow", "black", "red", "yellow", "orange", "red",
          "yellow", "aqua", "red", "yellow", "aqua"]

result = draw_segmentation_masks(image, masks, colors=colors)
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save("draw_masks_util.png")

[Image: draw_masks_util.png — output of draw_segmentation_masks]

I still think the colors look a bit spoilt, maybe due to the black background?

Feel free to comment on the PR. I think I'm missing just a small trick to make the masks transparent over the image 😄
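One candidate trick (just a sketch, not the PR code; it assumes class 0 is the background):

import torch

masks = torch.randint(0, 21, (224, 224))                           # predicted class ids
image = torch.ones((3, 224, 224), dtype=torch.uint8)               # original image
colored = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # colorized masks

# Keep the original pixels wherever the prediction is background, so only
# foreground classes get painted over the image
is_background = (masks == 0).unsqueeze(0)        # (1, H, W), broadcasts over channels
result = torch.where(is_background, image, colored)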
