A suite of semantic segmentation algorithms implemented with Jax and Flax

ChristianOrr/semantic-segmentation

Semantic Segmentation Algorithms

This repository contains a suite of semantic segmentation algorithms implemented from scratch with Jax and Flax. The main aim is to implement the algorithms in the simplest possible way while maintaining stability during training and reducing training time.

Implementation Details

The segmentation models are trained on RGB images from the Scene Parse 150 dataset [1], which contains 150 classes. Each model outputs a four-dimensional tensor of shape (B, H, W, C), where B is the batch size, H is the image height, W is the image width and C is the number of classes. Some alterations were made to the original models to speed up training and make them more robust, such as adding Group Norm and Dropout layers. The Dice loss is used as the loss function for both training and evaluation. The models are saved using the Orbax checkpointer and will be provided on Hugging Face once training has completed.
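To make the output shape and loss concrete, here is a minimal sketch of a soft multi-class Dice loss in plain JAX. It is an illustration only, not necessarily the loss code used in this repository: the function name `dice_loss`, the smoothing term `eps`, and the choice to average Dice over classes are all assumptions.

```python
import jax
import jax.numpy as jnp


def dice_loss(logits, labels, num_classes, eps=1e-6):
    """Soft multi-class Dice loss (hypothetical helper, not the repo's code).

    logits: (B, H, W, C) raw model outputs.
    labels: (B, H, W) integer class ids in [0, num_classes).
    """
    probs = jax.nn.softmax(logits, axis=-1)         # (B, H, W, C)
    one_hot = jax.nn.one_hot(labels, num_classes)   # (B, H, W, C)
    # Sum over batch and spatial axes, leaving one value per class.
    intersection = jnp.sum(probs * one_hot, axis=(0, 1, 2))
    cardinality = jnp.sum(probs + one_hot, axis=(0, 1, 2))
    dice_per_class = (2.0 * intersection + eps) / (cardinality + eps)
    # Loss is 1 minus the mean Dice score across classes.
    return 1.0 - jnp.mean(dice_per_class)


# Example with random logits shaped like the model output (B, H, W, C).
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 8, 8, 150))
labels = jnp.zeros((2, 8, 8), dtype=jnp.int32)
loss = dice_loss(logits, labels, num_classes=150)
```

The loss lies between 0 and 1, and it approaches 0 as the predicted probabilities match the one-hot labels.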

Algorithms

U-Net

U-Net builds on the ideas of the Fully Convolutional Network (FCN) and improves upon them. The main idea is an encoder-decoder architecture with skip connections from the encoder layers to the decoder layers. This provides both global and local information to the final segmentation layers, which improves classification and localization in the predicted segmentation. U-Net has a symmetric architecture, giving it the U shape it is named after. It's simpler to implement than FCN and is also very fast, which has made U-Net one of the most popular segmentation models today.
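The skip connection can be illustrated with a one-level sketch in plain JAX. This is a simplification under stated assumptions: the convolution blocks of a real U-Net are left out, and the helper names `downsample`, `upsample`, and `tiny_unet_path` are hypothetical and do not come from this repository.

```python
import jax
import jax.numpy as jnp


def downsample(x):
    """2x2 average pooling on (B, H, W, C); assumes even H and W."""
    b, h, w, c = x.shape
    return x.reshape(b, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))


def upsample(x):
    """2x nearest-neighbour upsampling on (B, H, W, C)."""
    b, h, w, c = x.shape
    return jax.image.resize(x, (b, h * 2, w * 2, c), method="nearest")


def tiny_unet_path(x):
    skip = x                        # high-resolution encoder features
    bottleneck = downsample(x)      # coarser features with more context
    decoded = upsample(bottleneck)  # back to the encoder resolution
    # Skip connection: concatenate encoder and decoder features on channels.
    return jnp.concatenate([skip, decoded], axis=-1)


x = jnp.ones((1, 16, 16, 4))
out = tiny_unet_path(x)  # channels double: 4 from the skip, 4 from the decoder
```

In a full model, each stage would apply convolutions before and after the concatenation, and several such levels would be nested to form the U shape.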

PSPNet

PSPNet was designed to address the lack of global scene understanding in FCN. It uses a pyramid pooling module (PPM) combined with a pretrained ResNet backbone to extract global context information. The PPM pools the feature map extracted from the backbone into feature maps at different scales. The scaled feature maps are then fused, upsampled, processed by a final convolutional module, and upsampled again to produce the segmentation mask. The lowest resolution PPM feature maps contain the coarsest information, which is ideal for understanding global context, while the highest resolution feature maps contain local information, which is ideal for localizing objects.
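The pooling and fusion steps can be sketched in plain JAX as follows. This is a rough illustration rather than this repository's implementation. The bin sizes (1, 2, 3, 6) follow the PSPNet paper, and whether this repository uses the same values is an assumption. The per-branch 1x1 convolutions are omitted, each branch is upsampled before concatenation for simplicity, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def avg_pool_to_bins(x, bins):
    """Average-pool (B, H, W, C) features down to (B, bins, bins, C).

    Assumes H and W are divisible by `bins`.
    """
    b, h, w, c = x.shape
    x = x.reshape(b, bins, h // bins, bins, w // bins, c)
    return x.mean(axis=(2, 4))


def pyramid_pool(features, bin_sizes=(1, 2, 3, 6)):
    """Pool at several scales, upsample each branch, and fuse on channels."""
    b, h, w, c = features.shape
    branches = [features]  # keep the original backbone features
    for bins in bin_sizes:
        pooled = avg_pool_to_bins(features, bins)  # coarse context map
        upsampled = jax.image.resize(pooled, (b, h, w, c), method="bilinear")
        branches.append(upsampled)
    return jnp.concatenate(branches, axis=-1)


# Stand-in for a backbone feature map; 12 is divisible by every bin size.
features = jnp.ones((2, 12, 12, 8))
fused = pyramid_pool(features)  # 8 original + 4 * 8 pooled channels
```

The 1-bin branch summarizes the whole image into a single vector per channel, while the 6-bin branch keeps more spatial detail.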

Installation Requirements

If you have a GPU, install Jax with CUDA support first by running the following:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

All the requirements are provided below:

pip install datasets
pip install flax
pip install augmax
pip install -qq nest_asyncio
pip install matplotlib
pip install pandas
pip install jupyter
pip install scikit-learn
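After installing, a quick check along these lines can confirm which devices JAX sees. This is an optional sanity check and not part of the repository.

```python
import jax

# List the devices JAX can use. On a CPU-only install this is a CPU device.
devices = jax.devices()
print(devices)
print("Default backend:", jax.default_backend())
```

If the CUDA install succeeded, the default backend should be reported as a GPU backend rather than `cpu`.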

References
