**Python library with Neural Networks for Image Segmentation based on [PyTorch](https://pytorch.org/).**

The main features of this library are:

- High-level API (just two lines to create a neural network)
- 8 model architectures for binary and multi-class segmentation (including legendary Unet)
- 57 available encoders for each architecture
- All encoders have pre-trained weights for faster and better convergence

### 📋 Table of contents
1. [Quick start](#start)
2. [Examples](#examples)
3. [Models](#models)
4. [Models API](#api)
5. [Installation](#installation)
6. [Competitions won with the library](#competitions-won-with-the-library)
7. [Contributing](#contributing)
8. [Citing](#citing)
9. [License](#license)

### ⏳ Quick start <a name="start"></a>

#### 1. Create your first Segmentation model with SMP

A segmentation model is just a PyTorch nn.Module, which can be created as easily as:

```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
```
- see [table](#architectures) with available model architectures
- see [table](#encoders) with available encoders and their corresponding weights
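
As a quick sanity check, you can run the freshly created model on a dummy batch. A minimal sketch, assuming the 1-channel, 3-class configuration above (spatial sizes typically must be divisible by 32 for the default encoder depth):

```python
import torch

x = torch.randn(4, 1, 256, 256)  # stand-in batch: 4 grayscale 256x256 images
with torch.no_grad():
    mask = model(x)

print(mask.shape)  # torch.Size([4, 3, 256, 256]) - one output channel per class
```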

#### 2. Configure data preprocessing

All encoders have pretrained weights. Preparing your data the same way as during weights pretraining may give you better results (higher metric score and faster convergence). But this is relevant only for 1-, 2-, and 3-channel images and is **not necessary** if you train the whole model, not only the decoder.

```python
from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
```
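
For illustration, here is a minimal sketch of applying the returned function to one image; the random array is a stand-in for a real RGB image, and the channel-last (HWC) layout is an assumption of this example:

```python
import numpy as np
import torch

image = np.random.randint(0, 255, size=(256, 256, 3)).astype("uint8")  # stand-in for a real RGB image

x = preprocess_input(image)               # normalize with the encoder's pretraining statistics
x = torch.from_numpy(x).permute(2, 0, 1)  # HWC -> CHW
x = x.unsqueeze(0).float()                # add batch dimension for the model
```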

Congratulations! You are done! Now you can train your model with your favorite framework!
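
For example, a bare-bones training step in plain PyTorch could look like the sketch below; the optimizer, loss, and tensor shapes are illustrative assumptions (using the 1-channel, 3-class Unet from above), not a prescribed recipe:

```python
import torch
import torch.nn as nn

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()           # multi-class loss over integer-encoded masks

images = torch.randn(8, 1, 256, 256)        # stand-in batch of grayscale images
masks = torch.randint(0, 3, (8, 256, 256))  # stand-in masks with class indices 0..2

optimizer.zero_grad()
logits = model(images)                      # (8, 3, 256, 256) raw class scores
loss = criterion(logits, masks)
loss.backward()
optimizer.step()
```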

### 💡 Examples <a name="examples"></a>
- Training a model for cars segmentation on the CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb).
- Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb)

### 📦 Models <a name="models"></a>

#### Architectures <a name="architectures"></a>
- [Unet](https://arxiv.org/abs/1505.04597) and [Unet++](https://arxiv.org/pdf/1807.10165.pdf)
- [Linknet](https://arxiv.org/abs/1707.03718)
- [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
- [PSPNet](https://arxiv.org/abs/1612.01105)
- [PAN](https://arxiv.org/abs/1805.10180)
- [DeepLabV3](https://arxiv.org/abs/1706.05587) and [DeepLabV3+](https://arxiv.org/abs/1802.02611)

#### Encoders <a name="encoders"></a>

<details>
<summary>Table with ALL available encoders (click to expand)</summary>

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|resnet18 |imagenet / ssl / swsl |11M |
|resnet34 |imagenet |21M |
|resnet50 |imagenet / ssl / swsl |23M |
|resnet101 |imagenet |42M |
|resnet152 |imagenet |58M |
|resnext50_32x4d |imagenet / ssl / swsl |22M |
|resnext101_32x4d |ssl / swsl |42M |
|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M |
|resnext101_32x16d |instagram / ssl / swsl |191M |
|resnext101_32x32d |instagram |466M |
|resnext101_32x48d |instagram |826M |
|dpn68 |imagenet |11M |
|densenet169 |imagenet |12M |
|densenet201 |imagenet |18M |
|densenet161 |imagenet |26M |
|inceptionresnetv2 |imagenet / imagenet+background |54M |
|inceptionv4 |imagenet / imagenet+background |41M |
|efficientnet-b0 |imagenet |4M |
|efficientnet-b1 |imagenet |6M |
|efficientnet-b2 |imagenet |7M |
|efficientnet-b3 |imagenet |10M |
|efficientnet-b4 |imagenet |17M |
|efficientnet-b5 |imagenet |28M |
|efficientnet-b6 |imagenet |40M |
|efficientnet-b7 |imagenet |63M |
|mobilenet_v2 |imagenet |2M |
|xception |imagenet |22M |
|timm-efficientnet-b0 |imagenet / advprop / noisy-student|4M |
|timm-efficientnet-b1 |imagenet / advprop / noisy-student|6M |
|timm-efficientnet-b2 |imagenet / advprop / noisy-student|7M |
|timm-efficientnet-b3 |imagenet / advprop / noisy-student|10M |
|timm-efficientnet-b4 |imagenet / advprop / noisy-student|17M |
|timm-efficientnet-b5 |imagenet / advprop / noisy-student|28M |
|timm-efficientnet-b6 |imagenet / advprop / noisy-student|40M |
|timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M |
|timm-efficientnet-b8 |imagenet / advprop |84M |
|timm-efficientnet-l2 |noisy-student |474M |

`ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

</details>

The most commonly used encoders:

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|resnet18 |imagenet / ssl / swsl |11M |
|resnet34 |imagenet |21M |
|resnet50 |imagenet / ssl / swsl |23M |
|resnet101 |imagenet |42M |
|resnext50_32x4d |imagenet / ssl / swsl |22M |
|resnext101_32x4d |ssl / swsl |42M |
|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M |
|senet154 |imagenet |113M |
|se_resnext50_32x4d |imagenet |25M |
|se_resnext101_32x4d |imagenet |46M |
|densenet121 |imagenet |6M |
|densenet169 |imagenet |12M |
|densenet201 |imagenet |18M |
|inceptionresnetv2 |imagenet / imagenet+background |54M |
|inceptionv4 |imagenet / imagenet+background |41M |
|mobilenet_v2 |imagenet |2M |
|timm-efficientnet-b0 |imagenet / advprop / noisy-student|4M |
|timm-efficientnet-b1 |imagenet / advprop / noisy-student|6M |
|timm-efficientnet-b2 |imagenet / advprop / noisy-student|7M |
|timm-efficientnet-b3 |imagenet / advprop / noisy-student|10M |
|timm-efficientnet-b4 |imagenet / advprop / noisy-student|17M |
|timm-efficientnet-b5 |imagenet / advprop / noisy-student|28M |
|timm-efficientnet-b6 |imagenet / advprop / noisy-student|40M |
|timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M |


### 🔁 Models API <a name="api"></a>

- `model.encoder` - pretrained backbone to extract features of different spatial resolution
- `model.decoder` - depends on the model's architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)

The `encoder_depth` parameter sets the number of downsampling stages in the encoder, so you can make the model lighter by specifying a smaller depth:

```python
model = smp.Unet('resnet34', encoder_depth=4)
```
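
As an illustration of the API above, the encoder can also be used on its own as a feature extractor; a minimal sketch, assuming the 3-channel `resnet34` model just created:

```python
import torch

# the encoder returns a list of feature maps at progressively smaller spatial sizes
with torch.no_grad():
    features = model.encoder(torch.randn(1, 3, 256, 256))

print([tuple(f.shape) for f in features])
```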


### 🛠 Installation <a name="installation"></a>
PyPI version:
```bash
$ pip install segmentation-models-pytorch
```

Latest version from source:
```bash
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
```

### 🏆 Competitions won with the library

The `Segmentation Models` package is widely used in image segmentation competitions.
[Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find competitions, names of the winners and links to their solutions.

### 🤝 Contributing

##### Run tests
```bash
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
```

##### Generate table
```bash
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
```

### 📝 Citing
```
@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models Pytorch},
  Year = {2020},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}},
}
```

### 🛡️ License <a name="license"></a>
The project is distributed under the [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE).