From 7958d72e6221944de24051b9c4b83203b50563c7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 Sep 2022 19:39:39 -0400 Subject: [PATCH] Update Imagenet example --- examples/ImageNet/Project.toml | 45 +++ examples/ImageNet/README.md | 168 ++++++--- examples/ImageNet/config.jl | 48 +++ examples/ImageNet/data.jl | 112 ++++++ examples/ImageNet/main.jl | 656 +++++++++------------------------ examples/ImageNet/utils.jl | 198 ++++++++++ examples/Project.toml | 23 -- 7 files changed, 699 insertions(+), 551 deletions(-) create mode 100644 examples/ImageNet/Project.toml create mode 100644 examples/ImageNet/config.jl create mode 100644 examples/ImageNet/data.jl create mode 100644 examples/ImageNet/utils.jl diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml new file mode 100644 index 000000000..9a367f7b5 --- /dev/null +++ b/examples/ImageNet/Project.toml @@ -0,0 +1,45 @@ +[deps] +Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15" +Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" +FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" +JLSO = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11" +JpegTurbo = "b835a17e-a41a-41e7-81f0-2f016b05efe0" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleConfig = "f2d95530-262a-480f-aff0-1c0431e662a7" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Augmentor = "0.6" +Boltz = "0.1" +CUDA = "3" +Configurations = "0.17" +FLoops = "0.2" +FluxMPI = "0.6" +Formatting = "0.4" +Functors = "0.2, 0.3" +Images = "0.24, 0.25" +JLSO = "2" +JpegTurbo = "0.1" +Lux = "0.4" +MLUtils = "0.2.10" +NNlib = "0.8" +OneHotArrays = "0.1" +Optimisers = "0.2" +Setfield = "0.8.2" +SimpleConfig = "0.1" +Zygote = "0.6" diff --git a/examples/ImageNet/README.md b/examples/ImageNet/README.md index 5b3d9c20c..8aa8f34f4 100644 --- a/examples/ImageNet/README.md +++ b/examples/ImageNet/README.md @@ -1,82 +1,142 @@ # Imagenet Training using Lux -This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset. +This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on +the ImageNet dataset. ## Requirements * Install [julia](https://julialang.org/) * In the Julia REPL instantiate the `Project.toml` in the parent directory * Download the ImageNet dataset from http://www.image-net.org/ - - Then, move and extract the training and validation images to labeled subfolders, using [the following shell script](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) + - Then, move and extract the training and validation images to labeled subfolders, using + [this shell script](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) ## Training -To train a model, run `main.jl` with the desired model architecture and the path to the ImageNet dataset: +To train a model, run `main.jl` with the necessary parameters. See +[Boltz documentation](http://lux.csail.mit.edu/stable/lib/Boltz/) for the model +configuration. ```bash -julia --project=.. -t 8 main.jl --arch ResNet18 [imagenet-folder with train and val folders] -``` - -The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG: - -```bash -julia --project=.. -t 8 main.jl --arch AlexNet --learning-rate 0.01 [imagenet-folder with train and val folders] +julia --project=examples/ImageNet -t 4 examples/ImageNet/main.jl\ + --cfg.dataset.data_root=/home/avik-pal/data/ImageNet/\ + --cfg.dataset.train_batchsize=256 --cfg.dataset.eval_batchsize=256\ + --cfg.optimizer.learning_rate=0.5 + +julia --project=examples/ImageNet -t 4 examples/ImageNet/main.jl\ + --cfg.model.name=alexnet --cfg.model.arch=alexnet\ + --cfg.dataset.data_root=/home/avik-pal/data/ImageNet/\ + --cfg.dataset.train_batchsize=256 --cfg.dataset.eval_batchsize=256\ + --cfg.optimizer.learning_rate=0.01 ``` ## Distributed Data Parallel Training -Setup [MPI.jl](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) preferably with the system MPI. Set `FLUXMPI_DISABLE_CUDAMPI_SUPPORT=true` to disable communication via CuArrays (note that this will lead to a very high communication bottleneck). +Setup [MPI.jl](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) +preferably with the system MPI. Set `FLUXMPI_DISABLE_CUDAMPI_SUPPORT=true` to disable +communication via CuArrays (note that this might lead to a very high communication +bottleneck). -**Learning Rate**: Remember to linearly scale the learning-rate based on the number of processes you are using. +!!! tip "Learning Rate" -**NOTE**: If using CUDA-aware MPI you need to disable the default CUDA allocator by `export JULIA_CUDA_MEMORY_POOL=none`. This might slow down your code slightly but will prevent any sudden segfaults which occur without setting this parameter. + Remember to linearly scale the learning-rate based on the number of processes you are + using. +!!! note -## Usage + If using CUDA-aware MPI you need to disable the default CUDA allocator by + `export JULIA_CUDA_MEMORY_POOL=none`. This might slow down your code slightly but will + prevent any sudden segfaults which occur without setting this parameter. ```bash -usage: main.jl [--arch ARCH] [--epochs EPOCHS] - [--start-epoch START-EPOCH] [--batch-size BATCH-SIZE] - [--learning-rate LEARNING-RATE] [--momentum MOMENTUM] - [--weight-decay WEIGHT-DECAY] [--print-freq PRINT-FREQ] - [--resume RESUME] [--evaluate] [--pretrained] - [--seed SEED] [--distributed] [-h] data +mpiexecjl -np 4 julia --project=examples/ImageNet -t 4 examples/ImageNet/main.jl\ + --cfg.dataset.data_root=/home/avik-pal/data/ImageNet/\ + --cfg.dataset.train_batchsize=256 --cfg.dataset.eval_batchsize=256\ + --cfg.optimizer.learning_rate=0.5 +``` -Lux ImageNet Training -positional arguments: - data path to dataset +## Usage + +```bash +usage: main.jl [--cfg.seed CFG.SEED] [--cfg.model.name CFG.MODEL.NAME] + [--cfg.model.arch CFG.MODEL.ARCH] + [--cfg.model.pretrained CFG.MODEL.PRETRAINED] + [--cfg.optimizer.name CFG.OPTIMIZER.NAME] + [--cfg.optimizer.learning_rate CFG.OPTIMIZER.LEARNING_RATE] + [--cfg.optimizer.nesterov CFG.OPTIMIZER.NESTEROV] + [--cfg.optimizer.momentum CFG.OPTIMIZER.MOMENTUM] + [--cfg.optimizer.weight_decay CFG.OPTIMIZER.WEIGHT_DECAY] + [--cfg.optimizer.scheduler.name CFG.OPTIMIZER.SCHEDULER.NAME] + [--cfg.optimizer.scheduler.cycle_length CFG.OPTIMIZER.SCHEDULER.CYCLE_LENGTH] + [--cfg.optimizer.scheduler.damp_factor CFG.OPTIMIZER.SCHEDULER.DAMP_FACTOR] + [--cfg.optimizer.scheduler.lr_step CFG.OPTIMIZER.SCHEDULER.LR_STEP] + [--cfg.optimizer.scheduler.lr_step_decay CFG.OPTIMIZER.SCHEDULER.LR_STEP_DECAY] + [--cfg.train.total_steps CFG.TRAIN.TOTAL_STEPS] + [--cfg.train.evaluate_every CFG.TRAIN.EVALUATE_EVERY] + [--cfg.train.resume CFG.TRAIN.RESUME] + [--cfg.train.evaluate CFG.TRAIN.EVALUATE] + [--cfg.train.checkpoint_dir CFG.TRAIN.CHECKPOINT_DIR] + [--cfg.train.log_dir CFG.TRAIN.LOG_DIR] + [--cfg.train.expt_subdir CFG.TRAIN.EXPT_SUBDIR] + [--cfg.train.expt_id CFG.TRAIN.EXPT_ID] + [--cfg.train.print_frequency CFG.TRAIN.PRINT_FREQUENCY] + [--cfg.dataset.data_root CFG.DATASET.DATA_ROOT] + [--cfg.dataset.eval_batchsize CFG.DATASET.EVAL_BATCHSIZE] + [--cfg.dataset.train_batchsize CFG.DATASET.TRAIN_BATCHSIZE] + [-h] optional arguments: - --arch ARCH model architectures: VGG19, ResNet50, - GoogLeNet, ResNeXt152, DenseNet201, - MobileNetv3_small, ResNet34, ResNet18, - DenseNet121, ResNet101, VGG13_BN, DenseNet169, - MobileNetv1, VGG11_BN, DenseNet161, - MobileNetv3_large, VGG11, VGG19_BN, VGG16_BN, - VGG16, ResNeXt50, AlexNet, VGG13, ResNeXt101, - MobileNetv2, ConvMixer or ResNet152 (default: - "ResNet18") - --epochs EPOCHS number of total epochs to run (type: Int64, - default: 90) - --start-epoch START-EPOCH - manual epoch number (useful on restarts) - (type: Int64, default: 0) - --batch-size BATCH-SIZE - mini-batch size, this is the total batch size - across all GPUs (type: Int64, default: 256) - --learning-rate LEARNING-RATE - initial learning rate (type: Float32, default: - 0.1) - --momentum MOMENTUM momentum (type: Float32, default: 0.9) - --weight-decay WEIGHT-DECAY - weight decay (type: Float32, default: 0.0001) - --print-freq PRINT-FREQ - print frequency (type: Int64, default: 10) - --resume RESUME resume from checkpoint (default: "") - --evaluate evaluate model on validation set - --pretrained use pre-trained model - --seed SEED seed for initializing training. (type: Int64, - default: 0) + --cfg.seed CFG.SEED (type: Int64, default: 12345) + --cfg.model.name CFG.MODEL.NAME + (default: "resnet") + --cfg.model.arch CFG.MODEL.ARCH + (default: "resnet18") + --cfg.model.pretrained CFG.MODEL.PRETRAINED + (type: Bool, default: false) + --cfg.optimizer.name CFG.OPTIMIZER.NAME + (default: "adam") + --cfg.optimizer.learning_rate CFG.OPTIMIZER.LEARNING_RATE + (type: Float32, default: 0.01) + --cfg.optimizer.nesterov CFG.OPTIMIZER.NESTEROV + (type: Bool, default: false) + --cfg.optimizer.momentum CFG.OPTIMIZER.MOMENTUM + (type: Float32, default: 0.0) + --cfg.optimizer.weight_decay CFG.OPTIMIZER.WEIGHT_DECAY + (type: Float32, default: 0.0) + --cfg.optimizer.scheduler.name CFG.OPTIMIZER.SCHEDULER.NAME + (default: "step") + --cfg.optimizer.scheduler.cycle_length CFG.OPTIMIZER.SCHEDULER.CYCLE_LENGTH + (type: Int64, default: 50000) + --cfg.optimizer.scheduler.damp_factor CFG.OPTIMIZER.SCHEDULER.DAMP_FACTOR + (type: Float32, default: 1.2) + --cfg.optimizer.scheduler.lr_step CFG.OPTIMIZER.SCHEDULER.LR_STEP + (type: Vector{Int64}, default: [100000, 250000, 500000]) + --cfg.optimizer.scheduler.lr_step_decay CFG.OPTIMIZER.SCHEDULER.LR_STEP_DECAY + (type: Float32, default: 0.1) + --cfg.train.total_steps CFG.TRAIN.TOTAL_STEPS + (type: Int64, default: 800000) + --cfg.train.evaluate_every CFG.TRAIN.EVALUATE_EVERY + (type: Int64, default: 10000) + --cfg.train.resume CFG.TRAIN.RESUME + (default: "") + --cfg.train.evaluate CFG.TRAIN.EVALUATE + (type: Bool, default: false) + --cfg.train.checkpoint_dir CFG.TRAIN.CHECKPOINT_DIR + (default: "checkpoints") + --cfg.train.log_dir CFG.TRAIN.LOG_DIR + (default: "logs") + --cfg.train.expt_subdir CFG.TRAIN.EXPT_SUBDIR + (default: "") + --cfg.train.expt_id CFG.TRAIN.EXPT_ID + (default: "") + --cfg.train.print_frequency CFG.TRAIN.PRINT_FREQUENCY + (type: Int64, default: 100) + --cfg.dataset.data_root CFG.DATASET.DATA_ROOT + (default: "") + --cfg.dataset.eval_batchsize CFG.DATASET.EVAL_BATCHSIZE + (type: Int64, default: 64) + --cfg.dataset.train_batchsize CFG.DATASET.TRAIN_BATCHSIZE + (type: Int64, default: 64) -h, --help show this help message and exit -``` \ No newline at end of file +``` diff --git a/examples/ImageNet/config.jl b/examples/ImageNet/config.jl new file mode 100644 index 000000000..5e509ca65 --- /dev/null +++ b/examples/ImageNet/config.jl @@ -0,0 +1,48 @@ +@option struct ModelConfig + name::String = "resnet" + arch::String = "resnet18" + pretrained::Bool = false +end + +@option struct SchedulerConfig + name::String = "step" + cycle_length::Int = 50000 + damp_factor::Float32 = 1.2f0 + lr_step::Vector{Int64} = [100000, 250000, 500000] + lr_step_decay::Float32 = 0.1f0 +end + +@option struct OptimizerConfig + name::String = "adam" + learning_rate::Float32 = 0.01f0 + nesterov::Bool = false + momentum::Float32 = 0.0f0 + weight_decay::Float32 = 0.0f0 + scheduler::SchedulerConfig = SchedulerConfig() +end + +@option struct TrainConfig + total_steps::Int = 800000 + evaluate_every::Int = 10000 + resume::String = "" + evaluate::Bool = false + checkpoint_dir::String = "checkpoints" + log_dir::String = "logs" + expt_subdir::String = "" + expt_id::String = "" + print_frequency::Int = 100 +end + +@option struct DatasetConfig + data_root::String = "" + eval_batchsize::Int = 64 + train_batchsize::Int = 64 +end + +@option struct ExperimentConfig + seed::Int = 12345 + model::ModelConfig = ModelConfig() + optimizer::OptimizerConfig = OptimizerConfig() + train::TrainConfig = TrainConfig() + dataset::DatasetConfig = DatasetConfig() +end diff --git a/examples/ImageNet/data.jl b/examples/ImageNet/data.jl new file mode 100644 index 000000000..e6dee8c67 --- /dev/null +++ b/examples/ImageNet/data.jl @@ -0,0 +1,112 @@ +# DataLoading +struct ImageDataset + image_files::Any + labels::Any + mapping::Any + augmentation_pipeline::Any + normalization_parameters::Any +end + +function ImageDataset(folder::String, augmentation_pipeline, normalization_parameters) + ulabels = readdir(folder) + label_dirs = joinpath.((folder,), ulabels) + @assert length(label_dirs)==1000 "There should be 1000 subdirectories in $folder" + + classes = readlines(joinpath(@__DIR__, "synsets.txt")) + mapping = Dict(z => i for (i, z) in enumerate(ulabels)) + + istrain = endswith(folder, r"train|train/") + + if istrain + image_files = vcat(map((x, y) -> joinpath.((x,), y), label_dirs, + readdir.(label_dirs))...) + + remove_files = [ + "n01739381_1309.JPEG", + "n02077923_14822.JPEG", + "n02447366_23489.JPEG", + "n02492035_15739.JPEG", + "n02747177_10752.JPEG", + "n03018349_4028.JPEG", + "n03062245_4620.JPEG", + "n03347037_9675.JPEG", + "n03467068_12171.JPEG", + "n03529860_11437.JPEG", + "n03544143_17228.JPEG", + "n03633091_5218.JPEG", + "n03710637_5125.JPEG", + "n03961711_5286.JPEG", + "n04033995_2932.JPEG", + "n04258138_17003.JPEG", + "n04264628_27969.JPEG", + "n04336792_7448.JPEG", + "n04371774_5854.JPEG", + "n04596742_4225.JPEG", + "n07583066_647.JPEG", + "n13037406_4650.JPEG", + "n02105855_2933.JPEG", + ] + remove_files = joinpath.((folder,), + joinpath.(first.(rsplit.(remove_files, "_", limit=2)), + remove_files)) + + image_files = [setdiff(Set(image_files), Set(remove_files))...] + + labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit=3))] + else + vallist = hcat(split.(readlines(joinpath(@__DIR__, "val_list.txt")))...) + labels = parse.(Int, vallist[2, :]) .+ 1 + filenames = [joinpath(classes[l], vallist[1, i]) for (i, l) in enumerate(labels)] + image_files = joinpath.((folder,), filenames) + idxs = findall(isfile, image_files) + image_files = image_files[idxs] + labels = labels[idxs] + end + + return ImageDataset(image_files, labels, mapping, augmentation_pipeline, + normalization_parameters) +end + +function Base.getindex(data::ImageDataset, i::Int) + img = Images.load(data.image_files[i]) + img = augment(img, data.augmentation_pipeline) + cimg = channelview(img) + if ndims(cimg) == 2 + cimg = reshape(cimg, 1, size(cimg, 1), size(cimg, 2)) + cimg = vcat(cimg, cimg, cimg) + end + img = Float32.(permutedims(cimg, (3, 2, 1))) + img = (img .- data.normalization_parameters.mean) ./ data.normalization_parameters.std + return img, onehot(data.labels[i], 1:1000) +end + +Base.length(data::ImageDataset) = length(data.image_files) + +function construct(cfg::DatasetConfig) + normalization_parameters = (mean=reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), + std=reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)) + train_data_augmentation = Resize(256, 256) |> FlipX(0.5) |> RCropSize(224, 224) + val_data_augmentation = Resize(256, 256) |> CropSize(224, 224) + train_dataset = ImageDataset(joinpath(cfg.data_root, "train"), train_data_augmentation, + normalization_parameters) + val_dataset = ImageDataset(joinpath(cfg.data_root, "val"), val_data_augmentation, + normalization_parameters) + if is_distributed() + train_dataset = DistributedDataContainer(train_dataset) + val_dataset = DistributedDataContainer(val_dataset) + end + + train_data = BatchView(shuffleobs(train_dataset); + batchsize=cfg.train_batchsize ÷ total_workers(), partial=false, + collate=true) + + val_data = BatchView(val_dataset; batchsize=cfg.eval_batchsize ÷ total_workers(), + partial=true, collate=true) + + train_iter = Iterators.cycle(MLUtils.eachobsparallel(train_data; executor=ThreadedEx(), + buffer=true)) + + val_iter = MLUtils.eachobsparallel(val_data; executor=ThreadedEx(), buffer=true) + + return train_iter, val_iter +end diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index 032a1696d..9b1b0764b 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -1,91 +1,49 @@ # Imagenet training script based on https://github.com/pytorch/examples/blob/main/imagenet/main.py -using ArgParse # Parse Arguments from Commandline using Augmentor # Image Augmentation +using Boltz # Computer Vision Models +using Configurations # Experiment Configurations using CUDA # GPUs <3 -using DataLoaders # Pytorch like DataLoaders using Dates # Printing current time -using Lux # Neural Network Framework using FluxMPI # Distibuted Training +using FLoops using Formatting # Pretty Printing using Functors # Parameter Manipulation +using JLSO # Serialization using Images # Image Processing -using Metalhead # Image Classification Models -using MLDataUtils # Shuffling and Splitting Data +using Lux # Neural Network Framework +using MLUtils # DataLoaders using NNlib # Neural Network Backend +using OneHotArrays # One Hot Arrays using Optimisers # Collection of Gradient Based Optimisers -using ParameterSchedulers # Collection of Schedulers for Parameter Updates using Random # Make things less Random -using Serialization # Serialize Models +using SimpleConfig # Extends Configurations.jl using Setfield # Easy Parameter Manipulation using Statistics # Statistics using Zygote # Our AD Engine -import Flux: OneHotArray, onecold, onehot, onehotbatch # Only being used for OneHotArrays -import DataLoaders: LearnBase # Extending Datasets -import MLUtils - # Distributed Training FluxMPI.Init(; verbose=true) -CUDA.allowscalar(false) - -# unsafe_free OneHotArrays -CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) - -# Image Classification Models -VGG11_BN(args...; kwargs...) = VGG11(args...; batchnorm=true, kwargs...) -VGG13_BN(args...; kwargs...) = VGG13(args...; batchnorm=true, kwargs...) -VGG16_BN(args...; kwargs...) = VGG16(args...; batchnorm=true, kwargs...) -VGG19_BN(args...; kwargs...) = VGG19(args...; batchnorm=true, kwargs...) -MobileNetv3_small(args...; kwargs...) = MobileNetv3(:small, args...; kwargs...) -MobileNetv3_large(args...; kwargs...) = MobileNetv3(:large, args...; kwargs...) -ResNeXt50(args...; kwargs...) = ResNeXt(50, args...; kwargs...) -ResNeXt101(args...; kwargs...) = ResNeXt(101, args...; kwargs...) -ResNeXt152(args...; kwargs...) = ResNeXt(152, args...; kwargs...) - -AVAILABLE_IMAGENET_MODELS = [ - AlexNet, - VGG11, - VGG13, - VGG16, - VGG19, - VGG11_BN, - VGG13_BN, - VGG16_BN, - VGG19_BN, - ResNet18, - ResNet34, - ResNet50, - ResNet101, - ResNet152, - ResNeXt50, - ResNeXt101, - ResNeXt152, - GoogLeNet, - DenseNet121, - DenseNet161, - DenseNet169, - DenseNet201, - MobileNetv1, - MobileNetv2, - MobileNetv3_small, - MobileNetv3_large, - ConvMixer, -] - -IMAGENET_MODELS_DICT = Dict(string(model) => model for model in AVAILABLE_IMAGENET_MODELS) - -function get_model(model_name::String, models_dict::Dict, rng, args...; warmup=true, - kwargs...) - model = Lux.transform(models_dict[model_name](args...; kwargs...).layers) - ps, st = Lux.setup(rng, model) .|> gpu - if warmup - # Warmup for compilation - x__ = randn(rng, Float32, 224, 224, 3, 1) |> gpu - y__ = onehotbatch([1], 1:1000) |> gpu - should_log() && println("$(now()) ==> staring `$model_name` warmup...") - model(x__, ps, st) - should_log() && println("$(now()) ==> forward pass warmup completed") + +# Experiment Configuration +includet("config.jl") +# Utility Functions +includet("utils.jl") +# DataLoading +includet("data.jl") + +function construct(rng::Random.AbstractRNG, cfg::ModelConfig, ecfg::ExperimentConfig) + model, ps, st = getfield(Boltz, Symbol(cfg.name))(Symbol(cfg.arch); cfg.pretrained) + ps, st = (ps, st) .|> gpu + + # Warmup for compilation + x__ = randn(rng, Float32, 224, 224, 3, 1) |> gpu + y__ = onehotbatch([1], 1:1000) |> gpu + should_log() && println("$(now()) ==> staring `$(cfg.arch)` warmup...") + model(x__, ps, st) + should_log() && println("$(now()) ==> forward pass warmup completed") + + if !ecfg.train.evaluate (l, _, _), back = Zygote.pullback(p -> logitcrossentropyloss(x__, y__, model, p, st), ps) back((one(l), nothing, nothing)) @@ -98,283 +56,54 @@ function get_model(model_name::String, models_dict::Dict, rng, args...; warmup=t should_log() && println("$(now()) ===> models synced across all ranks") end - return model, ps, st -end - -# Parse Training Arguments -function parse_commandline_arguments() - parse_settings = ArgParseSettings("Lux ImageNet Training") - @add_arg_table! parse_settings begin - """ - --arch - """ - default = "ResNet18" - range_tester = x -> x ∈ keys(IMAGENET_MODELS_DICT) - help = "model architectures: " * join(keys(IMAGENET_MODELS_DICT), ", ", " or ") - """ - --epochs - """ - help = "number of total epochs to run" - arg_type = Int - default = 90 - """ - --start-epoch - """ - help = "manual epoch number (useful on restarts)" - arg_type = Int - default = 0 - """ - --batch-size - """ - help = "mini-batch size, this is the total batch size across all GPUs" - arg_type = Int - default = 256 - """ - --learning-rate - """ - help = "initial learning rate" - arg_type = Float32 - default = 0.1f0 - """ - --momentum - """ - help = "momentum" - arg_type = Float32 - default = 0.9f0 - """ - --weight-decay - """ - help = "weight decay" - arg_type = Float32 - default = 1.0f-4 - """ - --print-freq - """ - help = "print frequency" - arg_type = Int - default = 10 - """ - --resume - """ - help = "resume from checkpoint" - arg_type = String - default = "" - """ - --evaluate - """ - help = "evaluate model on validation set" - action = :store_true - """ - --pretrained - """ - help = "use pre-trained model" - action = :store_true - """ - --seed - """ - help = "seed for initializing training. " - arg_type = Int - default = 0 - """ - data - """ - help = "path to dataset" - required = true - end - - return parse_args(parse_settings) -end - -# Loss Function -logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) - -function logitcrossentropyloss(x, y, model, ps, st) - ŷ, st_ = model(x, ps, st) - return logitcrossentropyloss(ŷ, y), ŷ, st_ -end - -# Optimisers / Parameter Schedulers -function update_lr(st::ST, eta) where {ST} - if hasfield(ST, :eta) - @set! st.eta = eta - end - return st -end -update_lr(st::Optimisers.OptimiserChain, eta) = update_lr.(st.opts, eta) -function update_lr(st::Optimisers.Leaf, eta) - @set! st.rule = update_lr(st.rule, eta) -end -update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) - -# Accuracy -function accuracy(ŷ, y, topk=(1,)) - maxk = maximum(topk) - - pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) - true_labels = onecold(y) - - accuracies = Vector{Float32}(undef, length(topk)) - - for (i, k) in enumerate(topk) - accuracies[i] = sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, - true_labels)) - end - - return accuracies .* 100 ./ size(y, ndims(y)) + return (model, ps, st) end -# Distributed Utils -is_distributed() = FluxMPI.Initialized() && total_workers() > 1 -should_log() = !FluxMPI.Initialized() || local_rank() == 0 - -# Checkpointing -function save_checkpoint(state, is_best, filename="checkpoint.pth.tar") - if should_log() - serialize(filename, state) - if is_best - cp(filename, "model_best.pth.tar"; force=true) +function construct(cfg::OptimizerConfig) + if cfg.name == "adam" + opt = Adam(cfg.learning_rate) + elseif cfg.name == "sgd" + if cfg.nesterov + opt = Nesterov(cfg.learning_rate, cfg.momentum) + elseif cfg.momentum == 0 + opt = Descent(cfg.learning_rate) + else + opt = Momentum(cfg.learning_rate, cfg.momentum) end - end -end - -# DataLoading -struct ImageDataset - image_files::Any - labels::Any - mapping::Any - augmentation_pipeline::Any - normalization_parameters::Any -end - -function ImageDataset(folder::String, augmentation_pipeline, normalization_parameters) - ulabels = readdir(folder) - label_dirs = joinpath.((folder,), ulabels) - @assert length(label_dirs)==1000 "There should be 1000 subdirectories in $folder" - - classes = readlines(joinpath(@__DIR__, "synsets.txt")) - mapping = Dict(z => i for (i, z) in enumerate(ulabels)) - - istrain = endswith(folder, r"train|train/") - - if istrain - image_files = vcat(map((x, y) -> joinpath.((x,), y), label_dirs, - readdir.(label_dirs))...) - - remove_files = [ - "n01739381_1309.JPEG", - "n02077923_14822.JPEG", - "n02447366_23489.JPEG", - "n02492035_15739.JPEG", - "n02747177_10752.JPEG", - "n03018349_4028.JPEG", - "n03062245_4620.JPEG", - "n03347037_9675.JPEG", - "n03467068_12171.JPEG", - "n03529860_11437.JPEG", - "n03544143_17228.JPEG", - "n03633091_5218.JPEG", - "n03710637_5125.JPEG", - "n03961711_5286.JPEG", - "n04033995_2932.JPEG", - "n04258138_17003.JPEG", - "n04264628_27969.JPEG", - "n04336792_7448.JPEG", - "n04371774_5854.JPEG", - "n04596742_4225.JPEG", - "n07583066_647.JPEG", - "n13037406_4650.JPEG", - "n02105855_2933.JPEG", - ] - remove_files = joinpath.((folder,), - joinpath.(first.(rsplit.(remove_files, "_", limit=2)), - remove_files)) - - image_files = [setdiff(Set(image_files), Set(remove_files))...] - - labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit=3))] else - vallist = hcat(split.(readlines(joinpath(@__DIR__, "val_list.txt")))...) - labels = parse.(Int, vallist[2, :]) .+ 1 - filenames = [joinpath(classes[l], vallist[1, i]) for (i, l) in enumerate(labels)] - image_files = joinpath.((folder,), filenames) - idxs = findall(isfile, image_files) - image_files = image_files[idxs] - labels = labels[idxs] + throw(ArgumentError("unknown value for `optimizer` = $(cfg.optimizer). Supported " * + "options are: `adam` and `sgd`.")) end - return ImageDataset(image_files, labels, mapping, augmentation_pipeline, - normalization_parameters) -end - -LearnBase.nobs(data::ImageDataset) = length(data.image_files) - -function LearnBase.getobs(data::ImageDataset, i::Int) - img = Images.load(data.image_files[i]) - img = augment(img, data.augmentation_pipeline) - cimg = channelview(img) - if ndims(cimg) == 2 - cimg = reshape(cimg, 1, size(cimg, 1), size(cimg, 2)) - cimg = vcat(cimg, cimg, cimg) + if cfg.weight_decay != 0 + opt = OptimiserChain(opt, WeightDecay(cfg.weight_decay)) end - img = Float32.(permutedims(cimg, (3, 2, 1))) - img = (img .- data.normalization_parameters.mean) ./ data.normalization_parameters.std - return img, onehot(data.labels[i], 1:1000) -end - -MLUtils.numobs(data::ImageDataset) = length(data.image_files) -MLUtils.getobs(data::ImageDataset, i::Int) = LearnBase.getobs(data, i) - -## DataLoaders doesn't yet work with MLUtils -LearnBase.nobs(data::DistributedDataContainer) = MLUtils.numobs(data) - -LearnBase.getobs(data::DistributedDataContainer, i::Int) = MLUtils.getobs(data, i) - -# Tracking -Base.@kwdef mutable struct AverageMeter - fmtstr::Any - val::Float64 = 0.0 - sum::Float64 = 0.0 - count::Int = 0 - average::Float64 = 0 -end - -function AverageMeter(name::String, fmt::String) - fmtstr = FormatExpr("$name {1:$fmt} ({2:$fmt})") - return AverageMeter(; fmtstr=fmtstr) -end - -function update!(meter::AverageMeter, val, n::Int) - meter.val = val - meter.sum += val * n - meter.count += n - meter.average = meter.sum / meter.count - return meter.average -end - -print_meter(meter::AverageMeter) = printfmt(meter.fmtstr, meter.val, meter.average) - -struct ProgressMeter{N} - batch_fmtstr::Any - meters::NTuple{N, AverageMeter} -end + if cfg.scheduler.name == "cosine" + scheduler = CosineAnnealSchedule(cfg.learning_rate, cfg.learning_rate / 100, + cfg.scheduler.cycle_length; + dampen=cfg.scheduler.damp_factor) + elseif cfg.scheduler.name == "constant" + scheduler = ConstantSchedule(cfg.learning_rate) + elseif cfg.scheduler.name == "step" + scheduler = Step(cfg.learning_rate, cfg.scheduler.lr_step_decay, + cfg.scheduler.lr_step) + else + throw(ArgumentError("unknown value for `lr_scheduler` = $(cfg.scheduler.name). " * + "Supported options are: `constant`, `step` and `cosine`.")) + end -function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} - fmt = "%" * string(length(string(num_batches))) * "d" - prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : "" - batch_fmtstr = generate_formatter("$prefix[$fmt/" * sprintf1(fmt, num_batches) * "]") - return ProgressMeter{N}(batch_fmtstr, meters) + return opt, scheduler end -function print_meter(meter::ProgressMeter, batch::Int) - base_str = meter.batch_fmtstr(batch) - print(base_str) - foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end]) - return println() +function loss_function(model, ps, st, (x, y)) + y_pred, st_ = model(x, ps, st) + loss = logitcrossentropy(y_pred, y) + return (loss, st_, (; y_pred)) end # Validation -function validate(val_loader, model, ps, st, args) +function validate(val_loader, model, ps, st, step, total_steps) batch_time = AverageMeter("Batch Time", "6.3f") data_time = AverageMeter("Data Time", "6.3f") forward_time = AverageMeter("Forward Pass Time", "6.3f") @@ -382,44 +111,100 @@ function validate(val_loader, model, ps, st, args) top1 = AverageMeter("Acc@1", "6.2f") top5 = AverageMeter("Acc@5", "6.2f") - progress = ProgressMeter(length(val_loader), + progress = ProgressMeter(total_steps, (batch_time, data_time, forward_time, losses, top1, top5), "Val:") st_ = Lux.testmode(st) t = time() - for (i, (x, y)) in enumerate(CuIterator(val_loader)) + for (i, (x, y)) in enumerate(val_loader) + x = x |> gpu + y = y |> gpu t_data, t = time() - t, time() + bsize = size(x, ndims(x)) + # Compute Output - ŷ, st_ = model(x, ps, st_) - loss = logitcrossentropyloss(ŷ, y) + y_pred, st_ = model(x, ps, st_) + loss = logitcrossentropyloss(y_pred, y) t_forward = time() - t # Metrics - acc1, acc5 = accuracy(cpu(ŷ), cpu(y), (1, 5)) - update!(top1, acc1, size(x, ndims(x))) - update!(top5, acc5, size(x, ndims(x))) - update!(losses, loss, size(x, ndims(x))) + acc1, acc5 = accuracy(cpu(y_pred), cpu(y), (1, 5)) + top1(acc1, bsize) + top5(acc5, bsize) + losses(loss, bsize) # Measure Elapsed Time - update!(data_time, t_data, size(x, ndims(x))) - update!(forward_time, t_forward, size(x, ndims(x))) - update!(batch_time, t_data + t_forward, size(x, ndims(x))) - - # Print Progress - if i % args["print-freq"] == 0 || i == length(val_loader) - should_log() && print_meter(progress, i) - end + data_time(t_data, bsize) + forward_time(t_forward, bsize) + batch_time(t_data + t_forward, bsize) t = time() end - return top1.average, top5.average, losses.average + should_log() && print_meter(progress, step) + + return top1.average end -# Training -function train(train_loader, model, ps, st, optimiser_state, epoch, args) +# Main Function +function main(cfg::ExperimentConfig) + best_acc1 = 0 + + # Seeding + rng = get_prng(cfg.seed) + + # Model Construction + if should_log() + if cfg.model.pretrained + println("$(now()) => using pre-trained model `$(cfg.model.arch)`") + else + println("$(now()) => creating model `$(cfg.model.arch)`") + end + end + model, ps, st = construct(rng, cfg.model, cfg) + + # DataLoader + should_log() && println("$(now()) => creating dataloaders") + ds_train, ds_val = construct(cfg.dataset) + _, ds_train_state = iterate(ds_train) + + # Optimizer and Scheduler + should_log() && println("$(now()) => creating optimizer") + opt, scheduler = construct(cfg.optimizer) + opt_state = Optimisers.setup(opt, ps) + if is_distributed() + opt_state = FluxMPI.synchronize!(opt_state) + should_log() && println("$(now()) ==> synced optimiser state across all ranks") + end + + expt_name = ("name-$(cfg.model.name)_arch-$(cfg.model.arch)_id-$(cfg.train.expt_id)") + ckpt_dir = joinpath(cfg.train.expt_subdir, cfg.train.checkpoint_dir, expt_name) + log_dir = joinpath(cfg.train.expt_subdir, cfg.train.log_dir, expt_name) + if cfg.train.resume == "" + rpath = joinpath(ckpt_dir, "model_current.jlso") + else + rpath = cfg.train.resume + end + + ckpt = load_checkpoint(rpath) + if !isnothing(ckpt) + ps = ckpt.ps |> gpu + st = ckpt.st |> gpu + opt_state = fmap(gpu, ckpt.opt_state) + initial_step = ckpt.step + should_log() && println("$(now()) ==> training started from $initial_step") + else + initial_step = 1 + end + + validate(ds_val, model, ps, st, 0, cfg.train.total_steps) + cfg.train.evaluate && return + + GC.gc(true) + CUDA.reclaim() + batch_time = AverageMeter("Batch Time", "6.3f") data_time = AverageMeter("Data Time", "6.3f") forward_time = AverageMeter("Forward Pass Time", "6.3f") @@ -428,153 +213,76 @@ function train(train_loader, model, ps, st, optimiser_state, epoch, args) losses = AverageMeter("Loss", ".4e") top1 = AverageMeter("Acc@1", "6.2f") top5 = AverageMeter("Acc@5", "6.2f") - progress = ProgressMeter(length(train_loader), + + progress = ProgressMeter(cfg.train.total_steps, (batch_time, data_time, forward_time, backward_time, - optimize_time, losses, top1, top5), "Epoch: [$epoch]") + optimize_time, losses, top1, top5), "Train: ") st = Lux.trainmode(st) - t = time() - for (i, (x, y)) in enumerate(CuIterator(train_loader)) - t_data, t = time() - t, time() + for step in initial_step:(cfg.train.total_steps) + # Train Step + t = time() + (x, y), ds_train_state = iterate(ds_train, ds_train_state) + x = x |> gpu + y = y |> gpu + t_data = time() - t + + bsize = size(x, ndims(x)) # Gradients and Update - (loss, ŷ, st), back = Zygote.pullback(p -> logitcrossentropyloss(x, y, model, p, - st), ps) + (loss, st, stats), back = Zygote.pullback(p -> loss_function(model, p, st, (x, y)), + ps) t_forward, t = time() - t, time() gs = back((one(loss) / total_workers(), nothing, nothing))[1] t_backward, t = time() - t, time() if is_distributed() gs = FluxMPI.allreduce_gradients(gs) end - optimiser_state, ps = Optimisers.update!(optimiser_state, ps, gs) + opt_state, ps = Optimisers.update!(opt_state, ps, gs) t_opt = time() - t # Metrics - acc1, acc5 = accuracy(cpu(ŷ), cpu(y), (1, 5)) - update!(top1, acc1, size(x, ndims(x))) - update!(top5, acc5, size(x, ndims(x))) - update!(losses, loss, size(x, ndims(x))) + acc1, acc5 = accuracy(cpu(stats.y_pred), cpu(y), (1, 5)) + top1(acc1, bsize) + top5(acc5, bsize) + losses(loss, bsize) # Measure Elapsed Time - update!(data_time, t_data, size(x, ndims(x))) - update!(forward_time, t_forward, size(x, ndims(x))) - update!(backward_time, t_backward, size(x, ndims(x))) - update!(optimize_time, t_opt, size(x, ndims(x))) - update!(batch_time, t_data + t_forward + t_backward + t_opt, 1) + data_time(t_data, bsize) + forward_time(t_forward, bsize) + backward_time(t_backward, bsize) + optimize_time(t_opt, bsize) + batch_time(t_data + t_forward + t_backward + t_opt, bsize) # Print Progress - if i % args["print-freq"] == 0 || i == length(train_loader) - should_log() && print_meter(progress, i) + if step % cfg.train.print_frequency == 1 || step == cfg.train.total_steps + should_log() && print_meter(progress, step) + reset_meter!(progress) end - t = time() - end - - return ps, st, optimiser_state, (top1.average, top5.average, losses.average) -end - -# Main Function -function main(args) - best_acc1 = 0 - - # Seeding - rng = Random.default_rng() - Random.seed!(rng, args["seed"]) - - # Model Construction - if should_log() - if args["pretrained"] - println("$(now()) => using pre-trained model `$(args["arch"])`") - else - println("$(now()) => creating model `$(args["arch"])`") + if step % cfg.train.evaluate_every == 0 + acc1 = validate(ds_val, model, ps, st, step, cfg.train.total_steps) + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + save_state = (ps=ps |> cpu, st=st |> cpu, opt_state=fmap(cpu, opt_state), + step=step) + if should_log() + save_checkpoint(save_state; is_best, + filename=joinpath(ckpt_dir, "model_$(step).jlso")) + end end - end - model, ps, st = get_model(args["arch"], IMAGENET_MODELS_DICT, rng; warmup=true, - pretrain=args["pretrained"]) - - normalization_parameters = (mean=reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), - std=reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)) - train_data_augmentation = Resize(256, 256) |> FlipX(0.5) |> RCropSize(224, 224) - val_data_augmentation = Resize(256, 256) |> CropSize(224, 224) - train_dataset = ImageDataset(joinpath(args["data"], "train"), train_data_augmentation, - normalization_parameters) - val_dataset = ImageDataset(joinpath(args["data"], "val"), val_data_augmentation, - normalization_parameters) - if is_distributed() - train_dataset = DistributedDataContainer(train_dataset) - val_dataset = DistributedDataContainer(val_dataset) - end - train_loader = DataLoader(shuffleobs(train_dataset), - args["batch-size"] ÷ total_workers()) - val_loader = DataLoader(val_dataset, args["batch-size"] ÷ total_workers()) + # LR Update + opt_state = Optimisers.adjust(opt_state, scheduler(step + 1)) - # Optimizer and Scheduler - should_log() && println("$(now()) => creating optimiser") - optimiser = Optimisers.OptimiserChain(Optimisers.Momentum(args["learning-rate"], - args["momentum"]), - Optimisers.WeightDecay(args["weight-decay"])) - optimiser_state = Optimisers.setup(optimiser, ps) - if is_distributed() - optimiser_state = FluxMPI.synchronize!(optimiser_state) - should_log() && println("$(now()) ==> synced optimiser state across all ranks") - end - scheduler = Step(; λ=args["learning-rate"], γ=0.1f0, step_sizes=30) - - if args["resume"] != "" - if isfile(args["resume"]) - checkpoint = deserialize(args["resume"]) - args["start-epoch"] = checkpoint["epoch"] - optimiser_state = checkpoint["optimiser_state"] |> gpu - ps = checkpoint["model_parameters"] |> gpu - st = checkpoint["model_states"] |> gpu - should_log() && - println("$(now()) => loaded checkpoint `$(args["resume"])` (epoch $(args["start-epoch"]))") - else - should_log() && - println("$(now()) => no checkpoint found at `$(args["resume"])`") - end - end - - if args["evaluate"] - @assert !is_distributed() "We are not syncing statistics. For evaluation run on 1 process" - validate(val_loader, model, ps, st, args) - return + t = time() end - GC.gc(true) - CUDA.reclaim() - - for epoch in args["start-epoch"]:args["epochs"] - # Train for 1 epoch - ps, st, optimiser_state, _ = train(train_loader, model, ps, st, optimiser_state, - epoch, args) - - # Some Housekeeping - GC.gc(true) - CUDA.reclaim() - - # Evaluate on validation set - acc1, _, _ = validate(val_loader, model, ps, st, args) - - # ParameterSchedulers - eta_new = scheduler(epoch) - optimiser_state = update_lr(optimiser_state, eta_new) - - # Some Housekeeping - GC.gc(true) - CUDA.reclaim() - - # Remember Best Accuracy and Save Checkpoint - is_best = acc1 > best_acc1 - best_acc1 = max(acc1, best_acc1) - - save_state = Dict("epoch" => epoch, "arch" => args["arch"], - "model_states" => st |> cpu, "model_parameters" => ps |> cpu, - "optimiser_state" => optimiser_state |> cpu) - save_checkpoint(save_state, is_best) - end + return end -main(parse_commandline_arguments()) +if abspath(PROGRAM_FILE) == @__FILE__ + main(define_configuration(ARGS, ExperimentConfig, Dict{String, Any}())) +end diff --git a/examples/ImageNet/utils.jl b/examples/ImageNet/utils.jl new file mode 100644 index 000000000..7bfe55ec2 --- /dev/null +++ b/examples/ImageNet/utils.jl @@ -0,0 +1,198 @@ +CUDA.allowscalar(false) + +# unsafe_free OneHotArrays +CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) + +# Loss Function +logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) + +function logitcrossentropyloss(x, y, model, ps, st) + ŷ, st_ = model(x, ps, st) + return logitcrossentropyloss(ŷ, y), ŷ, st_ +end + +# Random +function get_prng(seed::Int) + @static if VERSION >= v"1.7" + return Xoshiro(seed) + else + return MersenneTwister(seed) + end +end + +# Accuracy +function accuracy(ŷ, y, topk=(1,)) + maxk = maximum(topk) + + pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) + true_labels = onecold(y) + + accuracies = Vector{Float32}(undef, length(topk)) + + for (i, k) in enumerate(topk) + accuracies[i] = sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, + true_labels)) + end + + return accuracies .* 100 ./ size(y, ndims(y)) +end + +# Distributed Utils +is_distributed() = FluxMPI.Initialized() && total_workers() > 1 +should_log() = !FluxMPI.Initialized() || local_rank() == 0 + +# Checkpointing +function save_checkpoint(state::NamedTuple; is_best::Bool, filename::String) + isdir(dirname(filename)) || mkpath(dirname(filename)) + JLSO.save(filename, :state => state) + is_best && _symlink_safe(filename, joinpath(dirname(filename), "model_best.jlso")) + _symlink_safe(filename, joinpath(dirname(filename), "model_current.jlso")) + return nothing +end + +function _symlink_safe(src, dest) + rm(dest; force=true) + return symlink(src, dest) +end + +function load_checkpoint(fname::String) + try + # NOTE(@avik-pal): ispath is failing for symlinks? + return JLSO.load(fname)[:state] + catch + @warn """$fname could not be loaded. This might be because the file is absent or is + corrupt. Proceeding by returning `nothing`.""" + return nothing + end +end + +# Parameter Scheduling +## Copied from ParameterSchedulers.jl due to its heavy dependencies +struct CosineAnnealSchedule{restart, T, S <: Integer} + range::T + offset::T + dampen::T + period::S + + function CosineAnnealSchedule(lambda_0, lambda_1, period; restart::Bool=true, + dampen=1.0f0) + range = abs(lambda_0 - lambda_1) + offset = min(lambda_0, lambda_1) + return new{restart, typeof(range), typeof(period)}(range, offset, dampen, period) + end +end + +function (s::CosineAnnealSchedule{true})(t) + d = s.dampen^div(t - 1, s.period) + return (s.range * (1 + cos(pi * mod(t - 1, s.period) / s.period)) / 2 + s.offset) / d +end + +function (s::CosineAnnealSchedule{false})(t) + return s.range * (1 + cos(pi * (t - 1) / s.period)) / 2 + s.offset +end + +struct Step{T, S} + start::T + decay::T + step_sizes::S + + function Step(start::T, decay::T, step_sizes::S) where {T, S} + _step_sizes = (S <: Integer) ? Iterators.repeated(step_sizes) : step_sizes + + return new{T, typeof(_step_sizes)}(start, decay, _step_sizes) + end +end + +(s::Step)(t) = s.start * s.decay^(searchsortedfirst(s.step_sizes, t - 1) - 1) + +struct ConstantSchedule{T} + val::T +end + +(s::ConstantSchedule)(t) = s.val + +# Tracking +Base.@kwdef mutable struct AverageMeter + fmtstr::Any + val::Float64 = 0.0 + sum::Float64 = 0.0 + count::Int = 0 + average::Float64 = 0 +end + +function AverageMeter(name::String, fmt::String) + fmtstr = Formatting.FormatExpr("$name {1:$fmt} ({2:$fmt})") + return AverageMeter(; fmtstr=fmtstr) +end + +function (meter::AverageMeter)(val, n::Int) + meter.val = val + s = val * n + if is_distributed() + v = [s, typeof(val)(n)] + v = FluxMPI.MPIExtensions.allreduce!(v, +, FluxMPI.MPI.COMM_WORLD) + s = v[1] + n = Int(v[2]) + end + meter.sum += s + meter.count += n + meter.average = meter.sum / meter.count + return meter.average +end + +function reset_meter!(meter::AverageMeter) + meter.val = 0.0 + meter.sum = 0.0 + meter.count = 0 + meter.average = 0.0 + return meter +end + +function print_meter(meter::AverageMeter) + return Formatting.printfmt(meter.fmtstr, meter.val, meter.average) +end + +# ProgressMeter +struct ProgressMeter{N} + batch_fmtstr::Any + meters::NTuple{N, AverageMeter} +end + +function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} + fmt = "%" * string(length(string(num_batches))) * "d" + prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : "" + batch_fmtstr = Formatting.generate_formatter("$prefix[$fmt/" * + Formatting.sprintf1(fmt, num_batches) * + "]") + return ProgressMeter{N}(batch_fmtstr, meters) +end + +function reset_meter!(meter::ProgressMeter) + reset_meter!.(meter.meters) + return meter +end + +function print_meter(meter::ProgressMeter, batch::Int) + base_str = meter.batch_fmtstr(batch) + print(base_str) + foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end]) + println() + return nothing +end + +get_loggable_values(meter::ProgressMeter) = getproperty.(meter.meters, :average) + +# Optimisers State +function Lux.cpu(l::Optimisers.Leaf) + @set! l.state = cpu(l.state) + return l +end + +function Lux.gpu(l::Optimisers.Leaf) + @set! l.state = gpu(l.state) + return l +end + +function logitcrossentropy(y_pred, y; dims=1) + return mean(-sum(y .* logsoftmax(y_pred; dims=dims); dims=dims)) +end diff --git a/examples/Project.toml b/examples/Project.toml index 471a6cb80..8deebf5e7 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,31 +1,19 @@ [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" -Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" -Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" -Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" -JpegTurbo = "b835a17e-a41a-41e7-81f0-2f016b05efe0" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -35,30 +23,19 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractDifferentiation = "0.4" -ArgParse = "1" -Augmentor = "0.6" CUDA = "3" ComponentArrays = "0.13" DataLoaders = "0.1" DiffEqSensitivity = "6" -Flux = "0.13" -FluxMPI = "0.5.3, 0.6" -Formatting = "0.4.2" ForwardDiff = "0.10" Functors = "0.2, 0.3" -ImageMagick = "1" -Images = "0.24, 0.25" -JpegTurbo = "0.1" Lux = "0.4" -MLDataUtils = "0.5" MLDatasets = "0.5, 0.7" MLUtils = "0.2" -Metalhead = "0.7" NNlib = "0.8" OneHotArrays = "0.1" Optimisers = "0.2" OrdinaryDiffEq = "6" -ParameterSchedulers = "0.3" Plots = "1" ReverseDiff = "1" Setfield = "0.8, 1"