This project provides an implementation of the InceptionV3 model, as described in the paper "Rethinking the Inception Architecture for Computer Vision".
The implementation is almost a one-to-one translation of the PyTorch implementation (also see the hub page).
Pre-trained weights for this model can be downloaded either from PyTorch (using torchvision) or from mseitzer/pytorch-fid.
The FID weights provided by mseitzer/pytorch-fid use PyTorch's legacy serialization format, which is not supported by Burn (or, more precisely, by Candle, which Burn uses in the background). The script download_fid_weights.py is therefore provided: it downloads the weights and re-saves them in the current PyTorch format.
To run the script:
# If no arguments are provided, the weights file will be saved to the default location:
# `~/.cache/inception-v3-burn/pt_inception-2015-12-05-6726825d.pth`
python download_fid_weights.py
# Alternatively, you can provide a custom path.
python download_fid_weights.py --file PATH_TO_FILE
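For orientation, the behaviour described above (a default cache location, an optional `--file` override, and a re-save step) could be sketched roughly as follows. This is an illustrative sketch, not the contents of download_fid_weights.py: the helper names `parse_args`, `resolve_output_path`, and `resave_weights` are hypothetical, the download step is omitted, and the re-save assumes that a plain `torch.load` followed by `torch.save` is enough to convert the file.

```python
import argparse
from pathlib import Path

# Default location quoted in the comments above; the constant name is hypothetical.
DEFAULT_WEIGHTS_PATH = (
    Path.home() / ".cache" / "inception-v3-burn" / "pt_inception-2015-12-05-6726825d.pth"
)


def parse_args(argv=None):
    """Parse the optional --file argument (None means 'use the default path')."""
    parser = argparse.ArgumentParser(description="Re-save the FID InceptionV3 weights.")
    parser.add_argument("--file", default=None, help="custom output path")
    return parser.parse_args(argv)


def resolve_output_path(file_arg):
    """Return the custom path if one was given, otherwise the default cache location."""
    return Path(file_arg).expanduser() if file_arg else DEFAULT_WEIGHTS_PATH


def resave_weights(legacy_path, output_path):
    """Load a legacy-format checkpoint and write it back in the current format."""
    import torch  # imported lazily so the path helpers work without PyTorch installed

    # Assumption: loading on the CPU and saving again is sufficient for the conversion.
    state_dict = torch.load(legacy_path, map_location="cpu")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, output_path)
```

Under these assumptions, a caller would pass `parse_args().file` to `resolve_output_path` and hand the result to `resave_weights` together with the path of the downloaded legacy file.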
Then, add the model to your dependencies:
[dependencies]
inception-v3-burn = { git = "https://github.com/varonroy/inception-v3-burn", features = ["pretrained"] }
Finally, initialize the model using the weights prepared in the previous steps.
use inception_v3_burn::model::{
    weights::{downloader::InceptionV3PretrainedLoader, WeightsSource},
    InceptionV3,
};

fn main() {
    type B = burn::backend::NdArray;
    let device = burn::backend::ndarray::NdArrayDevice::default();

    // If you saved the weights to a location other than the default one,
    // replace `None` with `Some(<fid-weights-file-path>)`.
    let (config, model) = InceptionV3::<B>::pretrained(WeightsSource::fid(None), &device).unwrap();
}
This implementation is licensed under the MIT license.
For the pre-trained weights' licenses, please refer to their original sources: