Finetune guide for semantic segmentation #18640

Merged
merged 11 commits on Sep 2, 2022
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -48,6 +48,8 @@
title: Automatic speech recognition
- local: tasks/image_classification
title: Image classification
- local: tasks/semantic_segmentation
title: Semantic segmentation
title: Fine-tune for downstream tasks
- local: run_scripts
title: Train with a script
284 changes: 284 additions & 0 deletions docs/source/en/tasks/semantic_segmentation.mdx
@@ -0,0 +1,284 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Semantic segmentation

<Youtube id="dKE8SIt9C-w"/>

Semantic segmentation assigns a label or class to each individual pixel of an image. There are several types of segmentation tasks, and semantic segmentation doesn't differentiate between unique instances of the same object: every object of the same class receives the same label (for example, each car in an image is labeled "car" rather than "car-1" and "car-2"). Common real-world applications of semantic segmentation include training self-driving cars to identify pedestrians and important traffic information, identifying cells and abnormalities in medical imagery, and monitoring environmental changes from satellite imagery.
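For a quick sense of what a semantic segmentation model produces, you can run an image through the [`pipeline`] API with an off-the-shelf checkpoint (a minimal sketch; any ADE20K-finetuned SegFormer checkpoint, such as `nvidia/segformer-b0-finetuned-ade-512-512`, works here):

```py
>>> from transformers import pipeline

>>> segmenter = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
>>> results = segmenter(
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png"
... )
>>> [result["label"] for result in results]  # one mask per predicted class, e.g. ['wall', 'floor', 'bed', ...]
```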

This guide will show you how to finetune [SegFormer](https://huggingface.co/docs/transformers/main/en/model_doc/segformer#segformer) on the [SceneParse150](https://huggingface.co/datasets/scene_parse_150) dataset.

<Tip>

See the image segmentation [task page](https://huggingface.co/tasks/image-segmentation) for more information about its associated models, datasets, and metrics.

</Tip>

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install -q datasets transformers evaluate
```

## Load SceneParse150 dataset

Load the first 50 examples of the SceneParse150 dataset from the 🤗 Datasets library so you can quickly train and test a model:

```py
>>> from datasets import load_dataset

>>> ds = load_dataset("scene_parse_150", split="train[:50]")
```

Split this dataset into a train and test set:

```py
>>> ds = ds.train_test_split(test_size=0.2)
>>> train_ds = ds["train"]
>>> test_ds = ds["test"]
```

Then take a look at an example:

```py
>>> train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
'scene_category': 368}
```

There is an `image`, an `annotation` (the segmentation map or label), and a `scene_category` field that describes the image scene, such as a kitchen or office. Both `image` and `annotation` are PIL images.
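The segmentation map is a single-channel image whose pixel values are class ids. For example, you can inspect it with NumPy to see which classes appear in this scene:

```py
>>> import numpy as np

>>> seg_map = np.array(train_ds[0]["annotation"])
>>> print(seg_map.shape)  # (height, width), matching the image dimensions above
>>> print(np.unique(seg_map))  # the class ids present in this scene; 0 is the background class
```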

You'll also want to create a dictionary that maps a label id to a label class, which will be useful when you set up the model later. Download the mappings from the Hub to create the `id2label` and `label2id` dictionaries, and store the number of labels in `num_labels` since the model and the metric both need it:

```py
>>> import json
>>> from huggingface_hub import cached_download, hf_hub_url

>>> repo_id = "datasets/huggingface/label-files"
>>> filename = "ade20k-id2label.json"
>>> id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)
```

## Preprocess

Next, load a SegFormer feature extractor to process the image into an array of tensors. Some datasets, like this one, use the zero index as the background class. However, the background class isn't included in the 150 classes, so you'll need to set `reduce_labels=True` to subtract one from all the labels. The zero index is replaced by `255`, so it is ignored by SegFormer's loss function:
NielsRogge (Contributor) commented on Aug 22, 2022: FYI, note that we're currently in the process of replacing feature extractors by image processors, as feature extractor is a confusing name.

stevhliu (Member, Author) replied: Thanks for the heads up, I'll update this again once the image processors are ready.

```py
>>> from transformers import AutoFeatureExtractor

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("nvidia/mit-b0", reduce_labels=True)
```
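To get a feel for what the feature extractor returns, you can run a single example through it (shapes shown in the comments assume the extractor's default 512x512 resize):

```py
>>> sample = train_ds[0]
>>> encoding = feature_extractor(sample["image"], sample["annotation"], return_tensors="pt")
>>> print(encoding["pixel_values"].shape)  # e.g. torch.Size([1, 3, 512, 512]), the normalized image
>>> print(encoding["labels"].shape)  # e.g. torch.Size([1, 512, 512]), with 255 marking pixels to ignore
```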

It is common to apply some data augmentations to an image dataset to make a model more robust against overfitting. In this guide, you'll use the [`ColorJitter`](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) function from [torchvision](https://pytorch.org/vision/stable/index.html) to randomly change the color properties of an image:

```py
>>> from torchvision.transforms import ColorJitter

>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
```

Now create two preprocessing functions that apply the `jitter` transform and return the `pixel_values` and `labels` of an image. `pixel_values` is the expected model input, and it is generated by the feature extractor. Apply `jitter` only to the images in the training set; for the test set, the feature extractor preprocesses the `images` and `labels` without any augmentation.

```py
>>> def train_transforms(example_batch):
...     images = [jitter(x) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = feature_extractor(images, labels)
...     return inputs


>>> def val_transforms(example_batch):
...     images = [x for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = feature_extractor(images, labels)
...     return inputs
```

To apply the transforms over the entire dataset, use the 🤗 Datasets [`~datasets.Dataset.set_transform`] function. The transform is applied on the fly, which is faster and consumes less disk space:

```py
>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
```
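You can verify the transforms are wired up correctly by indexing into a split; each example now returns the model-ready inputs instead of the raw PIL images:

```py
>>> example = train_ds[0]
>>> print(list(example.keys()))  # expect ['pixel_values', 'labels'] instead of the original columns
```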

## Train

<frameworkcontent>
<pt>
Load SegFormer with [`AutoModelForSemanticSegmentation`], specify the number of labels, and pass the model the mapping between label ids and label classes:

```py
>>> from transformers import AutoModelForSemanticSegmentation

>>> pretrained_model_name = "nvidia/mit-b0"
>>> model = AutoModelForSemanticSegmentation.from_pretrained(
...     pretrained_model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
... )
```

<Tip>

If you aren't familiar with finetuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#finetune-with-trainer)!

</Tip>

Define your training hyperparameters in [`TrainingArguments`]. It is important not to remove unused columns because that would drop the `image` column, and without the `image` column you can't create `pixel_values`. Set `remove_unused_columns=False` to prevent this behavior!

```py
>>> from transformers import TrainingArguments

>>> training_args = TrainingArguments(
...     output_dir="./results",
...     learning_rate=6e-5,
...     num_train_epochs=50,
...     per_device_train_batch_size=2,
...     per_device_eval_batch_size=2,
...     save_total_limit=3,
...     evaluation_strategy="steps",
...     save_strategy="steps",
...     save_steps=20,
...     eval_steps=20,
...     logging_steps=1,
...     eval_accumulation_steps=5,
...     remove_unused_columns=False,
... )
```

To evaluate model performance during training, you'll need to create a function to compute and report metrics. For semantic segmentation, you'll typically compute the [mean Intersection over Union](https://huggingface.co/spaces/evaluate-metric/mean_iou) (mean IoU), which measures the overlapping area between the predicted and ground truth segmentation maps.

Load the mean IoU from the 🤗 Evaluate library:

```py
>>> import evaluate

>>> metric = evaluate.load("mean_iou")
```
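To get an intuition for what the metric returns, here is a toy example on two tiny 2x2 segmentation maps (a minimal sketch, separate from the training loop):

```py
>>> import numpy as np

>>> toy_preds = [np.array([[0, 1], [1, 1]])]  # predicted class id per pixel
>>> toy_refs = [np.array([[0, 1], [0, 1]])]  # ground truth class id per pixel
>>> toy_metrics = metric.compute(predictions=toy_preds, references=toy_refs, num_labels=2, ignore_index=255)
>>> print(toy_metrics["mean_iou"])  # class 0 IoU = 1/2, class 1 IoU = 2/3, so the mean is roughly 0.58
```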

Then create a function to [`~evaluate.EvaluationModule.compute`] the metrics. The model returns logits at a reduced resolution, so they need to be upsampled to the size of the labels and converted into predicted labels before you can call [`~evaluate.EvaluationModule.compute`]:

```py
>>> import numpy as np
>>> import torch
>>> from torch import nn


>>> def compute_metrics(eval_pred):
...     with torch.no_grad():
...         logits, labels = eval_pred
...         logits_tensor = torch.from_numpy(logits)
...         # upsample the logits to the size of the labels before taking the argmax
...         logits_tensor = nn.functional.interpolate(
...             logits_tensor,
...             size=labels.shape[-2:],
...             mode="bilinear",
...             align_corners=False,
...         ).argmax(dim=1)

...         pred_labels = logits_tensor.detach().cpu().numpy()
...         metrics = metric.compute(
...             predictions=pred_labels,
...             references=labels,
...             num_labels=num_labels,
...             ignore_index=255,
...             reduce_labels=False,
...         )
...         # per-category metrics are returned as NumPy arrays; convert them to lists so they can be logged
...         for key, value in metrics.items():
...             if type(value) is np.ndarray:
...                 metrics[key] = value.tolist()
...         return metrics
```

Pass your model, training arguments, datasets, and metrics function to the [`Trainer`]:

```py
>>> from transformers import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=train_ds,
...     eval_dataset=test_ds,
...     compute_metrics=compute_metrics,
... )
```

Lastly, call [`~Trainer.train`] to finetune your model:

```py
>>> trainer.train()
```
</pt>
</frameworkcontent>

## Inference

Great, now that you've finetuned a model, you can use it for inference!

Load an image for inference. Since the dataset splits now return the preprocessed tensors instead of the raw images, reload an example from the original dataset:

```py
>>> image = load_dataset("scene_parse_150", split="train[:50]")[0]["image"]
>>> image
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png" alt="Image of bedroom"/>
</div>

Process the image with a feature extractor and place the `pixel_values` on a GPU:

```py
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU if available, otherwise use a CPU
>>> encoding = feature_extractor(image, return_tensors="pt")
>>> pixel_values = encoding.pixel_values.to(device)
```
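If the model isn't already on that device (for example, because you reloaded it in a fresh session), move it over and switch to evaluation mode first (a small optional step):

```py
>>> model = model.to(device)  # keep the model and inputs on the same device
>>> model = model.eval()  # disable dropout for inference
```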

Pass your input to the model and return the `logits`:

```py
>>> outputs = model(pixel_values=pixel_values)
>>> logits = outputs.logits.cpu()
```

Next, rescale the logits to the original image size:

```py
>>> upsampled_logits = nn.functional.interpolate(
...     logits,
...     size=image.size[::-1],
...     mode="bilinear",
...     align_corners=False,
... )

>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
```
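Since `pred_seg` holds a class id for every pixel, you can map the predicted ids back to readable class names with the `id2label` dictionary you built earlier:

```py
>>> predicted_classes = [id2label[int(label_id)] for label_id in pred_seg.unique()]
>>> print(predicted_classes)  # e.g. ['wall', 'floor', 'bed', ...] for the bedroom image
```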

To visualize the results, load the [dataset color palette](https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51) that maps each class to its RGB values, and define an `ade_palette()` helper that returns this list of colors (one RGB triplet per class). Then you can combine and plot your image and the predicted segmentation map:

```py
>>> import matplotlib.pyplot as plt

>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
>>> palette = np.array(ade_palette())
>>> for label, color in enumerate(palette):
...     color_seg[pred_seg == label, :] = color
>>> color_seg = color_seg[..., ::-1] # convert to BGR

>>> img = np.array(image) * 0.5 + color_seg * 0.5 # plot the image with the segmentation map
>>> img = img.astype(np.uint8)

>>> plt.figure(figsize=(15, 10))
>>> plt.imshow(img)
>>> plt.show()
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-preds.png" alt="Image of bedroom overlayed with segmentation map"/>
</div>