These options apply to all experiments. For options specific to individual experiments (mask-based consistency, augmentation-based consistency, ICT, VAT), see below.
--job_desc: provide a job description/name. For example, running the train_seg_semisup_mask_mt.py program with --job_desc=test_a_1 will save its log file to results/train_seg_semisup_mask_mt/log_test_a_1.txt, and models and predictions will be saved to the directory results/train_seg_semisup_mask_mt/test_a_1.
--dataset [default=pascal_aug]: select the dataset to train on; one of camvid, cityscapes, pascal, pascal_aug (Pascal VOC 2012 augmented with SBD), isic2017
--model [default=mean_teacher]: select the consistency model:
  mean_teacher: use the Mean Teacher model of Tarvainen et al.
  pi: use the Pi-model of Laine et al.
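With --model=mean_teacher, the teacher network's weights follow an exponential moving average (EMA) of the student's weights, with the EMA rate set by --teacher_alpha; the Pi-model of Laine et al. instead compares two predictions from the same network under different noise, so there is no separate teacher to update. A minimal sketch of the EMA update, with plain Python lists standing in for parameter tensors (the function name ema_update is hypothetical, not taken from the code base):

```python
def ema_update(teacher_params, student_params, alpha=0.99):
    """Return alpha * teacher + (1 - alpha) * student for each parameter.

    Plain lists of floats are a hypothetical stand-in for network tensors.
    """
    return [alpha * t + (1.0 - alpha) * s
            for t, s in zip(teacher_params, student_params)]


teacher = [0.0, 1.0]
student = [1.0, 1.0]
for _ in range(3):
    # a low alpha (0.5) is used here only to make the convergence visible
    teacher = ema_update(teacher, student, alpha=0.5)
print(teacher)  # [0.875, 1.0]
```

A high alpha such as the default 0.99 makes the teacher a slowly changing, smoothed copy of the student.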
--arch [default=resnet101_deeplab_imagenet]: select the model architecture:
  resnet50unet_imagenet: ResNet-50 based U-Net with ImageNet classification pre-training
  resnet101unet_imagenet: ResNet-101 based U-Net with ImageNet classification pre-training
  densenet161unet: DenseNet-161 based U-Net, randomly initialised
  densenet161unet_imagenet: DenseNet-161 based U-Net, with ImageNet classification pre-training
  resnet101_deeplab_coco: ResNet-101 based DeepLab v2, with COCO semantic segmentation pre-training
  resnet101_deeplab_imagenet [default]: ResNet-101 based DeepLab v2, with ImageNet classification pre-training
  resnet101_deeplab_imagenet_mittal_std: ResNet-101 based DeepLab v2, with ImageNet classification pre-training, using the mean and std-dev used for normalization by Mittal et al.
  resnet101_deeplabv3_coco: torchvision ResNet-101 based DeepLab v3, with COCO semantic segmentation pre-training
  resnet101_deeplabv3_imagenet: torchvision ResNet-101 based DeepLab v3, with ImageNet classification pre-training
  resnet101_deeplabv3plus_imagenet: ResNet-101 based DeepLab v3+, with ImageNet classification pre-training
  resnet101_pspnet_imagenet: ResNet-101 based PSPNet (see Pyramid Scene Parsing Network by Zhao et al.), with ImageNet classification pre-training. To use this architecture you need to install our modified version of MIT CSAIL's semantic-segmentation-pytorch library; grab the logits-from_models branch of https://github.com/Britefury/semantic-segmentation-pytorch
--freeze_bn: flag to enable freezing of batch-norm layers. Use for DeepLab models, or for resnet50unet_imagenet if using a batch size of 1
--opt_type [default=adam]: optimizer type; one of sgd, adam [default]
--sgd_momentum [default=0.9]: set momentum if using the SGD optimizer
--sgd_nesterov: flag to enable Nesterov momentum if using the SGD optimizer
--sgd_weight_decay [default=5e-4]: set weight decay if using the SGD optimizer
--learning_rate [default=1e-4]: set the learning rate (use 3e-5 for DeepLab v2 models, 1e-5 for DeepLab v3+, and 0.1 with the SGD optimizer for densenet161unet_imagenet on the isic2017 dataset)
--lr_sched [default=none]: learning rate scheduler type:
  none: no LR schedule
  stepped: stepped LR schedule (control with the --lr_step_epochs and --lr_step_gamma options)
  cosine: cosine schedule
  poly: polynomial schedule; control the exponent with --lr_poly_power
--lr_step_epochs: stepped LR schedule step epochs as a Python list, e.g. --lr_step_epochs=[30,60,80] will change the learning rate at epochs 30, 60 and 80
--lr_step_gamma [default=0.1]: stepped LR schedule gamma; multiply the learning rate by this factor at each step
--lr_poly_power [default=0.9]: polynomial LR schedule power; scale the learning rate by (1 - iter/max_iters)^p, where p is the value of this option
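The four --lr_sched choices can be summarised as a multiplier applied to --learning_rate. The sketch below is a hypothetical helper, not the code base's implementation; it assumes the cosine schedule is the common half-cosine decay to zero over all iterations and that max_iters is num_epochs * iters_per_epoch.

```python
import math


def lr_factor(sched, epoch, iteration, max_iters,
              step_epochs=(30, 60, 80), step_gamma=0.1, poly_power=0.9):
    """Multiplier for the base learning rate (hypothetical helper)."""
    if sched == 'none':
        return 1.0
    if sched == 'stepped':
        # multiply by step_gamma once for each step epoch already reached
        n_steps = sum(1 for e in step_epochs if epoch >= e)
        return step_gamma ** n_steps
    if sched == 'cosine':
        # assumed half-cosine decay from 1 to 0 over the whole run
        return 0.5 * (1.0 + math.cos(math.pi * iteration / max_iters))
    if sched == 'poly':
        # (1 - iter / max_iters) ^ power
        return (1.0 - iteration / max_iters) ** poly_power
    raise ValueError('Unknown LR schedule {}'.format(sched))


base_lr = 1e-4
print(base_lr * lr_factor('poly', epoch=0, iteration=500, max_iters=1000))
```

The stepped schedule depends only on the epoch, whereas the cosine and poly schedules here are driven by the iteration count, which gives a smooth decay within each epoch.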
--teacher_alpha [default=0.99]: EMA alpha used to update the teacher network when using the mean teacher model
--bin_fill_holes: flag to enable hole filling for the foreground class. Only usable for binary segmentation tasks, e.g. ISIC 2017 segmentation; used for the ISIC 2017 experiments.
--crop_size: size of crop to extract during training, as H,W, e.g. --crop_size=321,321. Should be provided.
--aug_hflip: augmentation: enable horizontal flip
--aug_vflip: augmentation: enable vertical flip
--aug_hvflip: augmentation: enable diagonal flip (swap X and Y axes)
--aug_scale_hung: augmentation: enable the scaling used by Hung et al. (scale factor chosen randomly between 0.5 and 1.5 in increments of 0.1).
--aug_max_scale [default=1.0]: augmentation: enable random scale augmentation; scale factor chosen in the range [1/aug_max_scale, aug_max_scale] from a log-uniform distribution (overridden if --aug_scale_hung is used)
--aug_scale_non_uniform: augmentation: enable non-uniform scaling (compatible with both --aug_scale_hung and --aug_max_scale)
--aug_rot_mag [default=0.0]: augmentation: random rotation magnitude in degrees; rotate by an angle chosen from the range [-aug_rot_mag, aug_rot_mag] (disabled by --aug_scale_hung)
--cons_loss_fn: consistency loss function:
  var [default for all experiments except VAT]: squared error between predicted probabilities
  bce: binary cross entropy, using teacher predictions as target
  kld [default for VAT experiment]: KL-divergence
  logits_var: squared error between predicted pre-softmax logits
  logits_smoothl1 [not available for VAT experiment]: Smooth L1 loss between predicted pre-softmax logits
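The --cons_loss_fn choices differ in whether they compare probabilities or raw logits. The following sketch computes each loss for a single pixel's class vector; the helper name cons_loss, the sum over classes, and the KL direction (teacher relative to student) are assumptions of this sketch, and the code base may reduce over classes and pixels differently.

```python
import math


def softmax(logits):
    # numerically stable softmax over one pixel's class logits
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cons_loss(kind, stu_logits, tea_logits, eps=1e-8):
    """Consistency loss for a single pixel, summed over classes (hypothetical helper)."""
    if kind == 'logits_var':
        # squared error between pre-softmax logits
        return sum((s - t) ** 2 for s, t in zip(stu_logits, tea_logits))
    if kind == 'logits_smoothl1':
        # Smooth L1 (beta = 1) between pre-softmax logits
        total = 0.0
        for s, t in zip(stu_logits, tea_logits):
            d = abs(s - t)
            total += 0.5 * d * d if d < 1.0 else d - 0.5
        return total
    stu_p, tea_p = softmax(stu_logits), softmax(tea_logits)
    if kind == 'var':
        # squared error between predicted probabilities
        return sum((s - t) ** 2 for s, t in zip(stu_p, tea_p))
    if kind == 'bce':
        # binary cross-entropy with teacher probabilities as targets
        return -sum(t * math.log(s + eps) + (1.0 - t) * math.log(1.0 - s + eps)
                    for s, t in zip(stu_p, tea_p))
    if kind == 'kld':
        # KL(teacher || student), direction assumed
        return sum(t * (math.log(t + eps) - math.log(s + eps))
                   for s, t in zip(stu_p, tea_p))
    raise ValueError('Unknown consistency loss {}'.format(kind))
```

The logit-based losses are not bounded the way probability-based ones are, which is one reason the choice of loss interacts with --cons_weight.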
--cons_weight [default=1.0, 0.3 for ICT experiment]: consistency loss weight
--conf_thresh [default=0.97]: confidence threshold
--conf_per_pixel: flag to enable applying the confidence threshold per pixel; otherwise the confidence mask is averaged
--rampup [default=0]: ramp up the consistency loss weight over the specified number of epochs using the sigmoid function specified in Laine et al. Only works with --conf_thresh=0.
--unsup_batch_ratio [default=1]: for each supervised batch, process this number of unsupervised batches. This proved successful for semi-supervised classification in UDA by Xie et al. and FixMatch by Sohn et al. Tests with CamVid yielded limited success for segmentation; we didn't try this with other datasets.
--num_epochs [default=300]: number of epochs to train for
--iters_per_epoch [default=-1]: number of iterations per epoch. If -1 is given, it will be the number of mini-batches required to cover the training set
--batch_size [default=10]: the mini-batch size
--n_sup [default=100]: the number of supervised samples to use during training. These will be randomly selected from the training set, using the random seed provided by --split_seed to initialise the RNG. Alternatively, if --split_path is provided, the first n_sup samples will be selected from the array of indices loaded from the specified file
--n_unsup [default=-1]: the number of unsupervised samples to use during training. If -1 is given, all training samples are used
--n_val [default=-1]: the number of samples used for validation. If the dataset provides separate validation and test sets (e.g. CamVid) this will be ignored. If -1 is given, the provided validation/test set will be used as the validation set. If a value is provided, n_val samples will be randomly selected, using --val_seed to initialise the RNG. If --split_path is provided, the last n_val samples in the index array will be used.
--split_seed [default=12345]: the seed used to initialise the RNG that selects supervised samples
--val_seed [default=131]: the seed used to initialise the RNG that selects validation samples
--split_path: the path of a pickle (.pkl) file from which an index array will be loaded. This index array will be used to select supervised and validation samples, rather than an RNG. Validation samples will be taken from the end, supervised samples from the start.
--save_preds: if enabled, after training the predictions for validation and test samples will be saved in the output directory (see --job_desc)
--save_model: if enabled, after training the model will be saved in the output directory (see --job_desc)
--num_workers [default=4]: the number of worker processes used to load data batches in the background
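--rampup and --conf_thresh are two alternative ways of down-weighting the consistency loss early in training or on uncertain pixels. A minimal sketch of both, assuming the ramp-up is the exp(-5(1 - t)^2) sigmoid from Laine et al. and that teacher confidence is its top class probability; the function names and the flat list-of-pixels representation are hypothetical.

```python
import math


def sigmoid_rampup(epoch, rampup_epochs):
    """Multiplier exp(-5 (1 - t)^2), with t = epoch / rampup_epochs clipped to [0, 1]."""
    if rampup_epochs <= 0:
        return 1.0  # --rampup=0: no ramp-up
    t = min(max(epoch / rampup_epochs, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


def confidence_masked_loss(pixel_losses, teacher_probs, conf_thresh=0.97,
                           conf_per_pixel=False):
    """Apply a teacher-confidence mask to per-pixel consistency losses."""
    # a pixel is confident if the teacher's top class probability reaches the threshold
    conf_mask = [1.0 if max(p) >= conf_thresh else 0.0 for p in teacher_probs]
    n = len(pixel_losses)
    if conf_per_pixel:
        # zero the loss at low-confidence pixels, then average
        return sum(l * m for l, m in zip(pixel_losses, conf_mask)) / n
    # otherwise scale the mean loss by the fraction of confident pixels
    return (sum(pixel_losses) / n) * (sum(conf_mask) / n)


losses = [2.0, 4.0]
probs = [[0.99, 0.01], [0.6, 0.4]]
print(confidence_masked_loss(losses, probs, conf_per_pixel=True))   # 1.0
print(confidence_masked_loss(losses, probs, conf_per_pixel=False))  # 1.5
```

With --conf_thresh=0 every pixel passes the threshold, so the mask has no effect and the ramp-up multiplier (times --cons_weight) is what scales the loss.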
These options apply to the train_seg_semisup_mask_mt.py program
--mask_mode [default=mix]: masking mode:
  zero: multiply input images by the mask, clearing masked regions to zero
  mix: use the mask to blend pairs of input images
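The two --mask_mode settings can be illustrated on flattened single-channel images. This is a sketch with hypothetical helper names; in particular, which image of a pair is taken where the mask is 1 is an assumption.

```python
def apply_mask_zero(image, mask):
    # 'zero' mode: pixels where the mask is 0 are cleared to zero
    return [x * m for x, m in zip(image, mask)]


def apply_mask_mix(image_a, image_b, mask):
    # 'mix' mode: take image_a where mask == 0 and image_b where mask == 1
    return [a * (1 - m) + b * m for a, b, m in zip(image_a, image_b, mask)]


print(apply_mask_zero([1, 2, 3], [1, 0, 1]))  # [1, 0, 3]
print(apply_mask_mix([1, 1], [5, 5], [0, 1]))  # [1, 5]
```

In a mask-based consistency setup of this kind, the teacher's predictions on the two unmixed images would be blended with the same mask to form the target for the student's prediction on the mixed image.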
--mask_prop_range [default=0.5]: mask proportion range; the proportion of the mask with a value of 1 will be drawn from this range. Either a single value for a fixed proportion (e.g. 0.5) or a range separated by a colon (e.g. 0.0:1.0).
--boxmask_n_boxes [default=1]: number of boxes to draw into the mask. Note that boxes are XOR'ed with one another
--boxmask_fixed_aspect_ratio: forces all boxes to have an aspect ratio that is the same as the image crop (see --crop_size). Enable this to precisely replicate Cutout or CutMix.
--boxmask_by_size: if enabled, the mask proportion will determine the box edge length, rather than the area it covers
--boxmask_outside_bounds: if enabled, box centres will be selected such that part of the box may lie outside the bounds of the image crop. Enable this to precisely replicate Cutout or CutMix.
--boxmask_no_invert: if enabled, boxes will have a value of 1 against a background of 0, rather than the other way around.
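The box-mask options interact, so a rough sketch of how they might combine may help. This is not the code base's generator: the helper name box_mask, the even split of area across boxes, the log-uniform random aspect ratio and the clipping of out-of-bounds boxes are all assumptions made for illustration. The mask is a list of rows of 0/1 values, and prop is the fraction of the mask that should be 1 (ignoring box overlaps).

```python
import math
import random


def box_mask(h, w, prop, n_boxes=1, fixed_aspect_ratio=False, by_size=False,
             outside_bounds=False, no_invert=False, rng=random):
    """Hypothetical sketch of a CutMix/Cutout-style box mask."""
    # By default boxes are 0 on a background of 1, so the boxes must cover
    # (1 - prop) of the crop for `prop` of the mask to be 1.
    box_frac = (prop if no_invert else 1.0 - prop) / n_boxes
    mask = [[0] * w for _ in range(h)]
    for _ in range(n_boxes):
        if box_frac <= 0.0:
            break
        if by_size:
            # the proportion sets the relative edge length, not the area
            rel_h = rel_w = box_frac
        elif fixed_aspect_ratio:
            # same aspect ratio as the crop: equal relative edge lengths
            rel_h = rel_w = math.sqrt(box_frac)
        else:
            # random aspect ratio with rel_h * rel_w == box_frac
            rel_h = math.exp(rng.random() * math.log(box_frac))
            rel_w = box_frac / rel_h
        bh, bw = round(rel_h * h), round(rel_w * w)
        if outside_bounds:
            # centre anywhere in the crop; the box is clipped at the edges
            y0 = round(rng.uniform(0, h) - bh / 2)
            x0 = round(rng.uniform(0, w) - bw / 2)
        else:
            # keep the whole box inside the crop
            y0, x0 = rng.randint(0, h - bh), rng.randint(0, w - bw)
        for y in range(max(y0, 0), min(y0 + bh, h)):
            for x in range(max(x0, 0), min(x0 + bw, w)):
                mask[y][x] ^= 1  # overlapping boxes are XOR'ed
    if not no_invert:
        mask = [[1 - v for v in row] for row in mask]
    return mask


m = box_mask(8, 8, 0.75, fixed_aspect_ratio=True, rng=random.Random(0))
print(sum(map(sum, m)))  # 48 of 64 pixels are 1
```

Drawing prop from --mask_prop_range before each call would give a varying mixing proportion per sample.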
These options apply to the train_seg_semisup_aug_mt.py program
--aug_offset_range [default=16]: augmentation: the centres of the two crops extracted from an unsupervised image will be offset from each other by an amount in the range [-aug_offset_range, aug_offset_range]
--aug_free_scale_rot: augmentation: the two crops can have different scale and rotation; only applies to --aug_max_scale and --aug_rot_mag; does not affect --aug_scale_hung
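The effect of --aug_offset_range can be sketched by choosing the top-left corners of two same-sized crops (offsetting corners of equal-sized crops offsets their centres by the same amount). The helper name paired_crops is hypothetical, and clamping the second crop to the image bounds is an assumption; the code base may handle borders differently, for example by padding.

```python
import random


def paired_crops(img_h, img_w, crop_h, crop_w, offset_range=16, rng=random):
    """Return top-left corners (y, x) of two crops offset by up to offset_range."""
    # first crop: anywhere it fits inside the image
    y0 = rng.randint(0, img_h - crop_h)
    x0 = rng.randint(0, img_w - crop_w)
    # second crop: offset along each axis by an integer in [-offset_range, offset_range]
    dy = rng.randint(-offset_range, offset_range)
    dx = rng.randint(-offset_range, offset_range)
    # clamp so the second crop also lies inside the image (assumption)
    y1 = min(max(y0 + dy, 0), img_h - crop_h)
    x1 = min(max(x0 + dx, 0), img_w - crop_w)
    return (y0, x0), (y1, x1)


print(paired_crops(500, 400, 321, 321, rng=random.Random(0)))
```

The overlapping region of the two crops is where their predictions can be compared pixel-for-pixel.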
These options apply to the train_seg_semisup_ict.py program
--ict_alpha [default=0.1]: alpha value that determines the shape of the Beta distribution from which blending factors are drawn
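In ICT-style consistency, a blending factor drawn from Beta(alpha, alpha) mixes two unsupervised images, and the target is the same blend of the teacher's predictions on the unmixed images. A minimal sketch on flattened values, assuming a symmetric Beta distribution parameterised by --ict_alpha; the helper name ict_blend is hypothetical.

```python
import random


def ict_blend(x_a, x_b, tea_a, tea_b, alpha=0.1, rng=random):
    """Blend two inputs and the corresponding teacher predictions with one factor."""
    lam = rng.betavariate(alpha, alpha)
    x_mix = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    target = [lam * a + (1.0 - lam) * b for a, b in zip(tea_a, tea_b)]
    return lam, x_mix, target


lam, x_mix, target = ict_blend([0.0, 1.0], [1.0, 0.0], [0.9, 0.1], [0.2, 0.8],
                               rng=random.Random(0))
print(lam, x_mix, target)
```

With a small alpha such as the default 0.1, Beta(alpha, alpha) places most of its mass near 0 and 1, so most blends stay close to one of the two images.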
These options apply to the train_seg_semisup_vat_mt.py program
--vat_radius [default=0.5]: the radius that the adversarial perturbation is scaled to. By default this radius is vat_radius * sqrt(H*W*C), where H and W are the crop size and C is the number of input channels.
--adaptive_vat_radius: if enabled, scale the VAT radius adaptively according to the content of each unsupervised image; the radius is vat_radius * (|dI/dx| + |dI/dy|) * 0.5, where dI/dx and dI/dy are the horizontal and vertical gradient images respectively
--vat_dir_from_student: if enabled, use the student network to estimate the perturbation direction rather than the teacher (only when using the mean teacher model)
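The radius formulas above can be checked numerically. Scaling to vat_radius * sqrt(H*W*C) means the perturbation's root-mean-square value per input element equals vat_radius. The sketch below is hypothetical: the function names are illustrative, the adaptive version uses simple forward differences on a single-channel image (the code base may compute gradients differently), and it only computes the per-pixel radius map without claiming how that map is combined with the perturbation direction.

```python
import math


def vat_radius_fixed(vat_radius, h, w, c):
    # default radius: vat_radius * sqrt(H * W * C)
    return vat_radius * math.sqrt(h * w * c)


def scale_to_radius(direction, radius):
    # rescale a (flattened) perturbation direction to the given L2 norm
    norm = math.sqrt(sum(d * d for d in direction))
    return [d * radius / norm for d in direction]


def vat_radius_adaptive(vat_radius, image):
    """Per-pixel radius vat_radius * (|dI/dx| + |dI/dy|) * 0.5 for a 2D image."""
    h, w = len(image), len(image[0])
    radius = []
    for y in range(h):
        row = []
        for x in range(w):
            # forward differences, taken as zero at the last column/row (assumption)
            dx = image[y][x + 1] - image[y][x] if x + 1 < w else 0.0
            dy = image[y + 1][x] - image[y][x] if y + 1 < h else 0.0
            row.append(vat_radius * (abs(dx) + abs(dy)) * 0.5)
        radius.append(row)
    return radius


print(vat_radius_fixed(0.5, 2, 2, 1))               # 1.0
print(scale_to_radius([3.0, 4.0], 10.0))            # [6.0, 8.0]
print(vat_radius_adaptive(1.0, [[0, 1], [0, 1]]))   # [[0.5, 0.0], [0.5, 0.0]]
```

The adaptive variant gives larger perturbations near edges and smaller ones in flat regions of the image.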