TF-Keras implementation of TRIQ, as described in the paper Transformer for Image Quality Assessment.
- Clone this repository.
- Install the required Python packages. The code was developed with PyCharm in Python 3.7. The requirements.txt file was generated by PyCharm, and the code should also run with the latest versions of the packages.
An example of training TRIQ can be found in train/train_triq.py. An argument parser could be used, but the authors prefer to define the parameters in a dictionary. It is easy to convert the script to take command-line arguments instead. In principle, the following parameters can be defined:
args = {}
args['multi_gpu'] = 0 # gpu setting, set to 1 for using multiple GPUs
args['gpu'] = 0 # If having multiple GPUs, specify which GPU to use
args['result_folder'] = r'..\databases\experiments' # Define result path
args['n_quality_levels'] = 5 # Choose between 1 (MOS prediction) and 5 (distribution prediction)
args['transformer_params'] = [2, 32, 8, 64] # Transformer hyperparameters; assumed order: [num_layers, d_model, num_heads, mlp_dim]
args['train_folders'] = [ # Define folders containing training images
r'..\databases\train\koniq_normal',
r'..\databases\train\koniq_small',
r'..\databases\train\live'
]
args['val_folders'] = [ # Define folders containing testing images
r'..\databases\val\koniq_normal',
r'..\databases\val\koniq_small',
r'..\databases\val\live'
]
args['koniq_mos_file'] = r'..\databases\koniq10k_images_scores.csv' # MOS (distribution of scores) file for KonIQ database
args['live_mos_file'] = r'..\databases\live_mos.csv' # MOS (and standard deviation of scores) file for LIVE-wild database
args['backbone'] = 'resnet50' # Choose from ['resnet50', 'vgg16']
args['weights'] = r'...\pretrained_weights\resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' # Define the path of ImageNet pretrained weights
args['initial_epoch'] = 0 # Define initial epoch for use in fine-tune
args['lr_base'] = 1e-4 / 2 # Define the base learning rate used in the warmup and rate decay approach
args['lr_schedule'] = True # Choose between True and False, indicating if learning rate schedule should be used or not
args['batch_size'] = 32 # Batch size; choose a value that fits in GPU memory
args['epochs'] = 120 # Maximum number of epochs; early stopping can optionally be set in the callbacks
args['image_aug'] = True # Choose between True and False, indicating if image augmentation should be used or not
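The dictionary above can be turned into command-line arguments with Python's standard argparse module. The sketch below is not part of the repository: the flag names simply mirror the dictionary keys, the defaults copy the values shown above, and the str_to_bool helper is hypothetical. It avoids argparse.BooleanOptionalAction because that needs Python 3.9, while the code was developed in Python 3.7.

```python
import argparse


def str_to_bool(value):
    # Hypothetical helper: argparse's type=bool would treat any non-empty
    # string (including 'False') as True, so parse the text explicitly.
    text = value.strip().lower()
    if text in ('true', '1', 'yes'):
        return True
    if text in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def build_parser():
    # Sketch only: flag names mirror the keys of the args dictionary.
    parser = argparse.ArgumentParser(description='Train TRIQ (sketch)')
    parser.add_argument('--multi_gpu', type=int, default=0)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--result_folder', default=r'..\databases\experiments')
    parser.add_argument('--n_quality_levels', type=int, choices=[1, 5], default=5)
    parser.add_argument('--transformer_params', type=int, nargs=4,
                        default=[2, 32, 8, 64])
    parser.add_argument('--train_folders', nargs='+', default=[])
    parser.add_argument('--val_folders', nargs='+', default=[])
    parser.add_argument('--koniq_mos_file',
                        default=r'..\databases\koniq10k_images_scores.csv')
    parser.add_argument('--live_mos_file', default=r'..\databases\live_mos.csv')
    parser.add_argument('--backbone', choices=['resnet50', 'vgg16'],
                        default='resnet50')
    parser.add_argument('--weights', default=None,
                        help='path of the ImageNet pretrained weights')
    parser.add_argument('--initial_epoch', type=int, default=0)
    parser.add_argument('--lr_base', type=float, default=1e-4 / 2)
    parser.add_argument('--lr_schedule', type=str_to_bool, default=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=120)
    parser.add_argument('--image_aug', type=str_to_bool, default=True)
    return parser


# vars() turns the parsed Namespace into a dict with the same keys as args.
# An explicit list is passed here; parse_args() with no list reads sys.argv.
args = vars(build_parser().parse_args(
    ['--backbone', 'vgg16', '--batch_size', '16', '--lr_schedule', 'false']))
print(args)
```

Because the result is a plain dictionary, the rest of a training script written against args['...'] could stay unchanged.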
After TRIQ has been trained and its weights have been stored in an h5 file, it can be used to predict the quality of images of arbitrary sizes:
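# create_triq_model is defined in the repository's model code; the import path below is an assumption, check the repository
# from models.triq_model import create_triq_model
args = {}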
args['n_quality_levels'] = 5
args['backbone'] = 'resnet50'
args['weights'] = r'..\TRIQ.h5' # Path of the trained TRIQ weights
model = create_triq_model(n_quality_levels=args['n_quality_levels'],
                          backbone=args['backbone'])
model.load_weights(args['weights'])
Then use ModelEvaluation to predict the quality of an image set.
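With n_quality_levels set to 5, the model predicts a distribution over five quality levels rather than a single score. A minimal sketch, assuming the output is a probability vector in which index i corresponds to quality level i + 1 (the 1 to 5 scale of the koniq-format MOS files), of collapsing that distribution into one MOS as the expected score. The function name is hypothetical, and examples\image_quality_prediction.py may do this differently.

```python
import numpy as np


def distribution_to_mos(prob):
    """Collapse a distribution over quality levels 1..N into a single MOS.

    Assumes prob has shape (N,) or (batch, N) and that level i has score i.
    """
    prob = np.asarray(prob, dtype=np.float64)
    levels = np.arange(1, prob.shape[-1] + 1)
    # Renormalise so the weights sum to one even if the input is unnormalised.
    prob = prob / prob.sum(axis=-1, keepdims=True)
    # Expected score: sum over levels of P(level) * level.
    return (prob * levels).sum(axis=-1)


# Hypothetical prediction for one image.
print(distribution_to_mos([0.0, 0.0, 0.238, 0.695, 0.067]))
```

Since model.predict on a batch returns one row per image, the same function could be applied to its output to get one MOS per image.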
In the "examples" folder, an example script examples\image_quality_prediction.py is provided that uses the trained weights to predict the quality of example images. In the "train" folder, an example script train\validation.py is provided that uses the trained weights to predict the quality of images in folders.
A potential issue is an image shape mismatch. For example, if an image is too large, line 146 in transformer_iqa.py should be changed to increase the pooling size, e.g., to self.pooling_small = MaxPool2D(pool_size=(4, 4)) or even larger.
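One plausible reading of this mismatch, stated as an assumption rather than taken from the code: the transformer's positional encoding covers only a fixed maximum number of spatial positions, so a larger image produces too many feature-map positions unless the pooling is coarser. The sketch estimates that count, assuming a backbone output stride of 32 (as for the standard ResNet50 and VGG16 feature extractors) followed by a non-overlapping max pooling. MAX_POSITIONS is a hypothetical placeholder, not a value from the repository.

```python
def estimate_positions(height, width, backbone_stride=32, pool_size=2):
    """Rough count of feature-map positions passed to the transformer.

    Assumes each side is divided by backbone_stride in the backbone and then
    by pool_size in MaxPool2D(pool_size=(pool_size, pool_size)); integer
    division approximates the exact padding behaviour of the layers.
    """
    rows = (height // backbone_stride) // pool_size
    cols = (width // backbone_stride) // pool_size
    return rows * cols


# Hypothetical placeholder: replace with the maximum number of positions
# supported by the positional encoding in your model.
MAX_POSITIONS = 192

for pool in (2, 4):
    n = estimate_positions(1536, 2048, pool_size=pool)
    print('pool %d: %d positions, fits: %s' % (pool, n, n <= MAX_POSITIONS))
```

For a 1024x768 image with 2x2 pooling the estimate is 16 x 12 = 192 positions. Doubling both sides quadruples the count, and a 4x4 pooling brings it back down to 192.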
This work uses two publicly available databases: KonIQ-10k (KonIQ-10k: An ecologically valid database for deep learning of blind image quality assessment, by V. Hosu, H. Lin, T. Sziranyi, and D. Saupe) and LIVE-wild (Massive online crowdsourced study of subjective and objective picture quality, by D. Ghadiyaram and A.C. Bovik).
- The two databases were merged and then split into training and testing sets. Please see the README in databases for details.
- Make MOS files (note: do NOT include a header line):
For a database with the score distribution available, the MOS file looks like this (koniq format), with the columns image path, voter number of quality scale 1, voter number of quality scale 2, voter number of quality scale 3, voter number of quality scale 4, voter number of quality scale 5, MOS or Z-score:
10004473376.jpg,0,0,25,73,7,3.828571429
10007357496.jpg,0,3,45,47,1,3.479166667
10007903636.jpg,1,0,20,73,2,3.78125
10009096245.jpg,0,0,21,75,13,3.926605505
For a database with the standard deviation available, the MOS file looks like this (live format), with the columns image path, standard deviation, MOS or Z-score:
t1.bmp,18.3762,63.9634
t2.bmp,13.6514,25.3353
t3.bmp,18.9246,48.9366
t4.bmp,18.2414,35.8863
The format of the MOS file ('koniq' or 'live') and the type of score ('mos' or 'z_score') should also be specified in misc/imageset_handler/get_image_scores.
- The folders containing the training and testing images are specified in the train script train/train_triq.py.
- Pretrained ImageNet weights can be downloaded (see the README in .\pretrained_weights) and pointed to in the train script.
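A minimal sketch of reading lines in the two MOS-file formats described above, assuming plain comma-separated text with no header row. The helpers are hypothetical and are not the repository's get_image_scores. For the koniq-format example rows, the last column equals the vote-weighted mean of the levels 1 to 5 (for the first row, (3*25 + 4*73 + 5*7) / 105 ≈ 3.8286), which mos_from_counts reproduces; when that column holds a Z-score instead, this check does not apply.

```python
def parse_koniq_line(line):
    """Split 'path,n1,n2,n3,n4,n5,score' into (path, distribution, score).

    n1..n5 are voter counts for quality levels 1..5; the distribution is the
    counts normalised to sum to one, and score is the MOS or Z-score column.
    """
    fields = [field.strip() for field in line.split(',')]
    counts = [int(c) for c in fields[1:6]]
    total = sum(counts)
    distribution = [c / total for c in counts]
    return fields[0], distribution, float(fields[6])


def parse_live_line(line):
    """Split 'path,std,score' into (path, standard deviation, score)."""
    path, std, score = [field.strip() for field in line.split(',')]
    return path, float(std), float(score)


def mos_from_counts(counts):
    """Vote-weighted mean of quality levels 1..len(counts)."""
    total = sum(counts)
    return sum(level * c for level, c in enumerate(counts, start=1)) / total


row = '10004473376.jpg,0,0,25,73,7,3.828571429'
path, distribution, score = parse_koniq_line(row)
print(path, [round(p, 3) for p in distribution], score)
print(round(mos_from_counts([0, 0, 25, 73, 7]), 6))
print(parse_live_line('t1.bmp,18.3762,63.9634'))
```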
TRIQ has been trained on the KonIQ-10k and LIVE-wild databases, and the weights file can be downloaded here.
Three other models are also included in this work. Their original implementations are used, and they can be found below:
- Koncept512 (KonIQ-10k: An ecologically valid database for deep learning of blind image quality assessment)
- SGDNet (SGDNet: An end-to-end saliency-guided deep neural network for no-reference image quality assessment)
- CaHDC (End-to-end blind image quality prediction with cascaded deep neural network)
We have conducted several experiments to evaluate the performance of TRIQ. Please see results.pdf for detailed results.
If errors or exceptions are encountered, please first check all the paths. If a problem persists after fixing any path issues, please report it in Issues.
- To be added
This work is heavily inspired by ViT (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale). The module vit_iqa contains an implementation of ViT for IQA, which mainly follows the implementation in ViT-PyTorch. Pretrained ViT weights can be downloaded here.