diff --git a/README.md b/README.md index 1d2990b0..614e9fb0 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ -------------------------------------------------------------------------------- -FlagAI (Fast LArge-scale General AI models) is an fast, easy-to-use and extensible toolkit for large-scale model. Our goal is to support training, fine-tuning, and deployment of large-scale models on various downstream tasks with multi-modality. Currently, we are focusing on NLP models and tasks. In near futher, we will support for other modalities. +FlagAI (Fast LArge-scale General AI models) is a fast, easy-to-use and extensible toolkit for large-scale models. Our goal is to support training, fine-tuning, and deployment of large-scale models on various downstream tasks with multi-modality. Currently, we are focusing on NLP models and tasks. In the near future, we will support other modalities. * Now it supports **WuDao GLM** with a maximum of 10 billion parameters (see [Introduction to GLM](/docs/GLM.md)). It also supports **BERT**, **RoBERTa**, **GPT2**, **T5**, and models from Huggingface Transformers. @@ -18,7 +18,7 @@ FlagAI (Fast LArge-scale General AI models) is an fast, easy-to-use and extensib * FlagAI is backed by the three most popular data/model parallel libraries — [PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM) — with seamless integration between them. Users can parallel their training/testing process with less than ten lines of code. -The code is partially based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). +The code is partially based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers), [timm](https://github.com/rwightman/pytorch-image-models) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). @@ -156,7 +156,6 @@ start with our [contributor guidelines](CONTRIBUTING.md) and then check these [open issues](https://github.com/BAAI-Open/FlagAI/issues) for specific tasks. ## Contact us -Scan wechat QR code diff --git a/README_zh.md b/README_zh.md index 88def64d..e8dc5a3f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -18,7 +18,7 @@ * 飞智由三个最流行的数据/模型并行库([PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM))提供支持，它们之间实现了无缝集成。 你可以用不到十行代码来并行你的训练/测试过程。 -本项目的部分代码基于[GLM](https://github.com/THUDM/GLM)，[Transformers](https://github.com/huggingface/transformers) 和 [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). +本项目的部分代码基于 [GLM](https://github.com/THUDM/GLM)，[Transformers](https://github.com/huggingface/transformers)，[timm](https://github.com/rwightman/pytorch-image-models) 和 [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
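The README bullets above mention parallelizing training with less than ten lines of code. As a quick orientation before the tutorial and example diffs below, here is a minimal sketch of the AutoLoader + Trainer workflow. The model name `vit-base-p16-224` and the Trainer arguments are taken from the `examples/vit_cifar100` scripts added in this change; the dataset and `collate_fn` are placeholders and are left out.

```python
# Minimal sketch of the FlagAI workflow (based on examples/vit_cifar100 in this
# change); dataset and collate_fn are placeholders and omitted here.
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.trainer import Trainer

# Load a pretrained backbone plus a classification head.
loader = AutoLoader(task_name="classification",
                    model_name="vit-base-p16-224",
                    num_classes=100)
model = loader.get_model()

# Switching env_type to "pytorchDDP" or "deepspeed" is what enables
# data-parallel training without further code changes.
trainer = Trainer(
    env_type="pytorch",
    experiment_name="vit-demo",
    batch_size=64,
    epochs=1,
    lr=2e-5,
    pytorch_device="cuda" if torch.cuda.is_available() else "cpu",
)

# trainer.train(model, train_dataset=train_dataset, collate_fn=collate_fn)
```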
diff --git a/doc_zh/TUTORIAL_4_TRAINER.md b/doc_zh/TUTORIAL_4_TRAINER.md index c70c7871..0e83ffb9 100644 --- a/doc_zh/TUTORIAL_4_TRAINER.md +++ b/doc_zh/TUTORIAL_4_TRAINER.md @@ -13,7 +13,7 @@ - [deepspeed](#deepspeed) - [pytorchDDP](#pytorchddp) - [deepspeed + megatron-lm](#deepspeed--megatron-lm) - +- [EnvTrainer](#EnvTrainer) Trainer 类提供了API用于多种并行框架的训练。API 支持在多个 GPU上使用Pytorch DDP/Deepspeed进行分布式训练，同时支持Megatron-LM+Deepspeed的混合并行分布式训练，同时也通过 NVIDIA Apex 实现混合精度。 ## 入门 @@ -335,3 +335,72 @@ trainer = MyTrainer( ) ``` +# EnvTrainer + +为了更容易地输入参数，我们提供了EnvTrainer来代替原来的Trainer。 +例如: +```python +# train.py +import torch +from flagai.env_args import EnvArgs +from flagai.env_trainer import EnvTrainer + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +env_args = EnvArgs( + env_type="pytorch", + experiment_name="vit-cifar100-single_gpu", + batch_size=150, + num_gpus=1, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_single_gpu", + save_interval=1000, + num_checkpoints=1, +) + +env_args.add_arg(arg_name="test1", default=0, type=int, ) +env_args_parse = env_args.parse_args() +trainer = EnvTrainer(env_args) +``` + +运行train.py文件时，可以通过命令行修改输入参数。 +```commandline +python train.py --batch_size=8 --epochs=10 +``` +如果你需要添加额外的参数，你可以调用这个函数: +```python +env_args.add_arg(arg_name="test1", default=0, type=int, ) +``` +然后你可以用如下命令运行train.py文件: +```commandline +python train.py --test1=1 +``` +更多的例子可以查看: + +1. [vit-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/vit_cifar100/train_env_trainer.py) + +2. [glm-title-generation-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/glm_title_generation/train_env_trainer.py) + + +# 使用 pytorchDDP launcher 或 deepspeed launcher 运行 +如果你使用多个GPU来训练模型，你可以直接运行train.py来调用FlagAI训练器中的启动器。 +```commandline +python train.py +``` +另外，你也可以使用pytorchDDP和deepspeed启动器来运行，例如: +### pytorchDDP +```commandline +python -m torch.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 17750 train_env_trainer.py --not_call_launch +``` +### deepspeed +```commandline +python -m deepspeed.launcher.launch --master_addr=172.31.125.121 --master_port=17500 train.py --not_call_launch +``` diff --git a/docs/TUTORIAL_4_TRAINER.md b/docs/TUTORIAL_4_TRAINER.md index f78526b4..2ec7d235 100644 --- a/docs/TUTORIAL_4_TRAINER.md +++ b/docs/TUTORIAL_4_TRAINER.md @@ -13,6 +13,9 @@ - [deepspeed](#deepspeed) - [pytorchDDP](#pytorchddp) - [deepspeed + megatron-lm](#deepspeed--megatron-lm) +- [EnvTrainer](#EnvTrainer) + + The Trainer class provides APIs for training with multiple parallel frameworks. The API supports distributed training with Pytorch DDP/Deepspeed on multiple GPUs, as well as mixed parallel distributed training with Megatron-LM+Deepspeed, and mixed precision via NVIDIA Apex. ## Getting Started @@ -341,3 +344,76 @@ trainer = MyTrainer( model_paralle_size = 2 ) ``` + +# EnvTrainer + +To make it easier to pass in parameters, we provide EnvTrainer as a replacement for the original Trainer.
+ +Take the following code as an example: +```python +# train.py +import torch +from flagai.env_args import EnvArgs +from flagai.env_trainer import EnvTrainer + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +env_args = EnvArgs( + env_type="pytorch", + experiment_name="vit-cifar100-single_gpu", + batch_size=150, + num_gpus=1, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_single_gpu", + save_interval=1000, + num_checkpoints=1, +) + +env_args.add_arg(arg_name="test1", default=0, type=int, ) +env_args_parse = env_args.parse_args() +trainer = EnvTrainer(env_args) +``` + +When you run the train.py file, you can modify the input parameters through the command line. +```commandline +python train.py --batch_size=8 --epochs=10 +``` +If you need to add additional parameters, you can call this function: +```python +env_args.add_arg(arg_name="test1", default=0, type=int, ) +``` +Then you can run the train.py file with the following command: +```commandline +python train.py --test1=1 +``` + +More examples: + +1. [vit-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/vit_cifar100/train_env_trainer.py) + +2. [glm-title-generation-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/glm_title_generation/train_env_trainer.py) + + +# Run with the pytorchDDP launcher or deepspeed launcher +If you use multiple GPUs to train models, you can run train.py directly, which calls the launcher in the FlagAI Trainer. +```commandline +python train.py +``` +In addition, you can also use the pytorchDDP and deepspeed launchers directly, for example: + +### pytorchDDP +```commandline +python -m torch.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 17750 train_env_trainer.py --not_call_launch +``` +### deepspeed +```commandline +python -m deepspeed.launcher.launch --master_addr=172.31.125.121 --master_port=17500 train.py --not_call_launch +``` \ No newline at end of file diff --git a/examples/glm_title_generation/train_env_trainer.py b/examples/glm_title_generation/train_env_trainer.py new file mode 100644 index 00000000..39c4524f --- /dev/null +++ b/examples/glm_title_generation/train_env_trainer.py @@ -0,0 +1,144 @@ +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +import os +import numpy as np +import torch +from torch.utils.data import Dataset +from flagai.auto_model.auto_loader import AutoLoader +from flagai.env_trainer import EnvTrainer +from flagai.env_args import EnvArgs +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# You can pass all parameters via the command line.
+# For example: python train_env_trainer.py --epochs=300 --batch_size=4 --env_type=pytorch +env_args = EnvArgs() +trainer = EnvTrainer(env_args) + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +src_dir = cur_dir + '/data/train.src' +tgt_dir = cur_dir + '/data/train.tgt' + +maxlen = 256 +auto_loader = AutoLoader("lm", + model_name="GLM-large-ch", + model_dir="./state_dict/") +model = auto_loader.get_model() +tokenizer = auto_loader.get_tokenizer() + +def read_file(): + src = [] + tgt = [] + + with open(src_dir, 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + src.append(line.strip('\n').lower()) + + with open(tgt_dir, 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + tgt.append(line.strip('\n').lower()) + + return src, tgt + + +class GLMSeq2seqDataset(Dataset): + + def __init__(self, + sents_src, + sents_tgt, + tokenizer, + max_src_length=300, + max_tgt_length=200): + super(GLMSeq2seqDataset, self).__init__() + self.sents_src = sents_src + self.sents_tgt = sents_tgt + self.tokenizer = tokenizer + self.max_src_length = max_src_length + self.max_tgt_length = max_tgt_length + self.no_block_position = False + + def __getitem__(self, i): + source_text = self.sents_src[i] + target_text = self.sents_tgt[i] + data = self.tokenizer.encode_plus(source_text, target_text) + + return data + + def __len__(self): + + return len(self.sents_src) + + +class GLMPoetryDynamicCollateFN(): #padding process in each batch + + def __init__(self, pad_id): + self.pad_id = pad_id + + def pad_token(self, tokens, max_length): + pad_len = max_length - len(tokens) + tokens += [self.pad_id] * pad_len + return tokens + + def pad_position_ids(self, position_ids, max_length): + pad_len = max_length - len(position_ids[0]) + position_ids[0] += [len(position_ids[0]) + x for x in range(pad_len)] + position_ids[1] += [1] * pad_len + return position_ids + + def pad_loss_mask(self, loss_mask, max_length): + pad_len = max_length - len(loss_mask) + loss_mask += [0] * pad_len + return loss_mask + + def __call__(self, batch): + input_ids = [data["input_ids"] for data in batch] + target_ids = [data["target_ids"] for data in batch] + position_ids = [data["position_ids"] for data in batch] + attention_mask = [data['attention_mask'] for data in batch] + loss_mask = [data['loss_mask'] for data in batch] + + max_length = max([len(t) for t in input_ids]) + for i in range(len(input_ids)): + input_ids[i] = self.pad_token(input_ids[i], max_length) + target_ids[i] = self.pad_token(target_ids[i], max_length) + position_ids[i] = self.pad_position_ids(position_ids[i], + max_length) + loss_mask[i] = self.pad_loss_mask(loss_mask[i], max_length) + return { + 'input_ids': torch.LongTensor(input_ids), + 'labels': torch.LongTensor(target_ids), + 'position_ids': torch.LongTensor(position_ids), + 'attention_mask': torch.LongTensor(attention_mask), + 'loss_mask': torch.LongTensor(loss_mask) + } + + +sents_src, sents_tgt = read_file() +my_collate_fn = GLMPoetryDynamicCollateFN( + pad_id=tokenizer.get_command('pad').Id) + +data_len = len(sents_tgt) +train_size = int(data_len * 0.8) +train_src = sents_src[:train_size][:2000] +train_tgt = sents_tgt[:train_size][:2000] + +val_src = sents_src[train_size:] +val_tgt = sents_tgt[train_size:] + +train_dataset = GLMSeq2seqDataset(train_src, + train_tgt, + tokenizer=tokenizer, + max_src_length=300, + max_tgt_length=200) +val_dataset = GLMSeq2seqDataset(val_src, + val_tgt, + tokenizer=tokenizer, + max_src_length=300, + max_tgt_length=200) + +trainer.train(model, 
+ train_dataset=train_dataset, + valid_dataset=val_dataset, + collate_fn=my_collate_fn) diff --git a/examples/vit_cifar100/README.md b/examples/vit_cifar100/README.md new file mode 100644 index 00000000..1cd35edd --- /dev/null +++ b/examples/vit_cifar100/README.md @@ -0,0 +1,163 @@ +# ViT for classification on the CIFAR-100 dataset + +Vision Transformer (ViT) is becoming increasingly popular in the field of +computer vision (CV). More and more tasks use ViT to achieve state-of-the-art results. + +The paper is at https://arxiv.org/pdf/2010.11929.pdf. + +The code is at https://github.com/google-research/vision_transformer. + +## How to use +We can easily use ViT to finetune on the CIFAR-100 dataset. +### Training +```python +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.trainer import Trainer +from flagai.auto_model.auto_loader import AutoLoader + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +trainer = Trainer( + env_type="pytorch", + experiment_name="vit-cifar100", + batch_size=64, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100", + save_interval=1000, + num_checkpoints=1, +) + +def build_cifar(): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.Resize(224), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR100(root="./cifar100", train=True, download=True, transform=transform_train) + test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test) + return train_dataset, test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + + +if __name__ == '__main__': + loader = AutoLoader(task_name="backbone", + model_name="Vit-base-p16", + num_classes=100) + + model = loader.get_model() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) + + train_dataset, val_dataset = build_cifar() + trainer.train(model, + optimizer=optimizer, + lr_scheduler=scheduler, + train_dataset=train_dataset, + valid_dataset=val_dataset, + metric_methods=[["accuracy", validate]], + collate_fn=collate_fn) +``` + +### Validation +Once you have trained a model, you can validate it with the following code.
+```python +import torch +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.auto_model.auto_loader import AutoLoader +import os +from tqdm import tqdm + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def build_cifar(): + + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test) + return test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + + model_save_dir = "./checkpoints_vit_cifar100" + print(f"loadding model in :{model_save_dir}") + loader = AutoLoader(task_name="backbone", + model_name="Vit-base-p16", + num_classes=100) + + model = loader.get_model() + + model.load_state_dict(torch.load(os.path.join(model_save_dir, "38000", "pytorch_model.bin"), map_location=device)["module"]) + print(f"model load success.......") + model.to(device) + + val_dataset = build_cifar() + + val_dataloader = DataLoader(val_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn) + index = 0 + accuracy = 0.0 + for data in tqdm(val_dataloader, total=len(val_dataloader)): + index += 1 + data = {k: v.to(device) for k, v in data.items()} + labels = data["labels"] + pred = model(**data)["logits"] + acc = validate(pred, labels) + accuracy += acc + + print(f"accuracy is {accuracy / index}") +``` diff --git a/examples/vit_cifar100/deepspeed.json b/examples/vit_cifar100/deepspeed.json new file mode 100644 index 00000000..f2339ca3 --- /dev/null +++ b/examples/vit_cifar100/deepspeed.json @@ -0,0 +1,48 @@ +{ + "train_micro_batch_size_per_gpu": 64, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": false, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e7, + "allgather_bucket_size": 5e7, + "cpu_offload": true + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-5, + "warmup_num_steps": 2000 + } + }, + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-5, + "weight_decay": 0.1, + "betas": [ + 0.9, + 0.98 + ], + "eps": 1e-6 + } + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false + } diff --git a/examples/vit_cifar100/hostfile b/examples/vit_cifar100/hostfile new file mode 100644 index 00000000..51356577 --- /dev/null +++ b/examples/vit_cifar100/hostfile @@ -0,0 +1 @@ +127.0.0.1 slots=2 \ No newline at end of file diff --git a/examples/vit_cifar100/train_DDP.py b/examples/vit_cifar100/train_DDP.py new file mode 100644 index 00000000..35c997b1 --- /dev/null +++ b/examples/vit_cifar100/train_DDP.py @@ -0,0 +1,87 @@ +import torch +from 
torchvision import transforms +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.trainer import Trainer +from flagai.auto_model.auto_loader import AutoLoader + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +env_type = "pytorchDDP" +trainer = Trainer( + env_type=env_type, + experiment_name="vit-cifar100-8gpu", + batch_size=150, + num_gpus=8, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_8gpu", + save_interval=1000, + num_checkpoints=1, + hostfile="./hostfile", + training_script="train_DDP.py" +) + +def build_cifar(): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.Resize(224), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train) + test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test) + return train_dataset, test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + if trainer.fp16: + images = images.half() + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + loader = AutoLoader(task_name="classification", + model_name="vit-base-p16-224", + num_classes=100) + + model = loader.get_model() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) + train_dataset, val_dataset = build_cifar() + + trainer.train(model, + optimizer=optimizer, + lr_scheduler=scheduler, + train_dataset=train_dataset, + valid_dataset=val_dataset, + metric_methods=[["accuracy", validate]], + collate_fn=collate_fn) + + + + + diff --git a/examples/vit_cifar100/train_deepspeed.py b/examples/vit_cifar100/train_deepspeed.py new file mode 100644 index 00000000..9d44b1df --- /dev/null +++ b/examples/vit_cifar100/train_deepspeed.py @@ -0,0 +1,88 @@ +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.trainer import Trainer +from flagai.auto_model.auto_loader import AutoLoader + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +env_type = "deepspeed" +trainer = Trainer( + env_type=env_type, + experiment_name="vit-cifar100-deepspeed", + batch_size=150, + num_gpus=8, + fp16=True, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_deepspeed", + save_interval=1000, + num_checkpoints=1, + hostfile="./hostfile", + training_script="train_deepspeed.py" +) + +def 
build_cifar(): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.Resize(224), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train) + test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test) + return train_dataset, test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + if trainer.fp16: + images = images.half() + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + loader = AutoLoader(task_name="classification", + model_name="vit-base-p16-224", + num_classes=100) + + model = loader.get_model() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) + train_dataset, val_dataset = build_cifar() + + trainer.train(model, + optimizer=optimizer, + lr_scheduler=scheduler, + train_dataset=train_dataset, + valid_dataset=val_dataset, + metric_methods=[["accuracy", validate]], + collate_fn=collate_fn) + + + + + diff --git a/examples/vit_cifar100/train_env_trainer.py b/examples/vit_cifar100/train_env_trainer.py new file mode 100644 index 00000000..fcb153cb --- /dev/null +++ b/examples/vit_cifar100/train_env_trainer.py @@ -0,0 +1,90 @@ +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.env_trainer import EnvTrainer +from flagai.auto_model.auto_loader import AutoLoader +import argparse + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from flagai.env_args import EnvArgs + +env_args = EnvArgs( + env_type="pytorch", + experiment_name="vit-cifar100-single_gpu", + batch_size=64, + num_gpus=1, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_single_gpu", + save_interval=1000, + num_checkpoints=1, +) + +env_args.add_arg(arg_name="test_args", default=0, type=int, ) +env_args = env_args.parse_args() +trainer = EnvTrainer(env_args) + +def build_cifar(): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.Resize(224), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train) + test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test) + return 
train_dataset, test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + if trainer.fp16: + images = images.half() + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + loader = AutoLoader(task_name="classification", + model_name="vit-base-p16-224", + num_classes=100) + + model = loader.get_model() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) + train_dataset, val_dataset = build_cifar() + + trainer.train(model, + optimizer=optimizer, + lr_scheduler=scheduler, + train_dataset=train_dataset, + valid_dataset=val_dataset, + metric_methods=[["accuracy", validate]], + collate_fn=collate_fn) + + + + + diff --git a/examples/vit_cifar100/train_single_gpu.py b/examples/vit_cifar100/train_single_gpu.py new file mode 100644 index 00000000..ef7e1356 --- /dev/null +++ b/examples/vit_cifar100/train_single_gpu.py @@ -0,0 +1,85 @@ +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.trainer import Trainer +from flagai.auto_model.auto_loader import AutoLoader + +lr = 2e-5 +n_epochs = 50 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +env_type = "pytorch" +trainer = Trainer( + env_type=env_type, + experiment_name="vit-cifar100-single_gpu", + batch_size=150, + num_gpus=1, + gradient_accumulation_steps=1, + lr=lr, + weight_decay=1e-5, + epochs=n_epochs, + log_interval=100, + eval_interval=1000, + load_dir=None, + pytorch_device=device, + save_dir="checkpoints_vit_cifar100_single_gpu", + save_interval=1000, + num_checkpoints=1, +) + +def build_cifar(): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.Resize(224), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train) + test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test) + return train_dataset, test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + if trainer.fp16: + images = images.half() + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + loader = AutoLoader(task_name="classification", + model_name="vit-base-p16-224", + num_classes=100) + + model = loader.get_model() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) + train_dataset, val_dataset = build_cifar() + + trainer.train(model, + optimizer=optimizer, + lr_scheduler=scheduler, + 
train_dataset=train_dataset, + valid_dataset=val_dataset, + metric_methods=[["accuracy", validate]], + collate_fn=collate_fn) + + + + + diff --git a/examples/vit_cifar100/validate.py b/examples/vit_cifar100/validate.py new file mode 100644 index 00000000..e52eb113 --- /dev/null +++ b/examples/vit_cifar100/validate.py @@ -0,0 +1,76 @@ +import torch +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR100 +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +from flagai.auto_model.auto_loader import AutoLoader +import os +from tqdm import tqdm + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def build_cifar(): + + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test) + return test_dataset + +def collate_fn(batch): + images = torch.stack([b[0] for b in batch]) + labels = [b[1] for b in batch] + labels = torch.tensor(labels).long() + return {"images": images, "labels": labels} + +def validate(logits, labels, meta=None): + _, predicted = logits.max(1) + total = labels.size(0) + correct = predicted.eq(labels).sum().item() + return correct / total + +if __name__ == '__main__': + + model_save_dir = "./checkpoints_vit_cifar100" + print(f"loadding model in :{model_save_dir}") + loader = AutoLoader(task_name="backbone", + model_name="vit-base-p16-224", + num_classes=100) + + model = loader.get_model() + + model.load_state_dict(torch.load(os.path.join(model_save_dir, "38000", "pytorch_model.bin"), map_location=device)["module"]) + print(f"model load success.......") + model.to(device) + + val_dataset = build_cifar() + + val_dataloader = DataLoader(val_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn) + index = 0 + accuracy = 0.0 + for data in tqdm(val_dataloader, total=len(val_dataloader)): + index += 1 + data = {k: v.to(device) for k, v in data.items()} + labels = data["labels"] + pred = model(**data)["logits"] + acc = validate(pred, labels) + accuracy += acc + + print(f"accuracy is {accuracy / index}") + + + + + + + + + + diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py index 37d7ca90..17d74d3d 100644 --- a/flagai/auto_model/auto_loader.py +++ b/flagai/auto_model/auto_loader.py @@ -54,26 +54,37 @@ def __getattr__(self, name): "glm_title-generation": ["flagai.model.glm_model", "GLMForSeq2Seq"], "opt_seq2seq": ("flagai.model.opt_model","OPTModel"), "opt_lm": ("flagai.model.opt_model","OPTModel"), + "vit_classification": ("flagai.model.vision.vit", "VisionTransformer") + } MODEL_DICT = { - "bert-base-en": ["flagai.model.bert_model", "BertModel", "bert"], - "roberta-base-ch": ["flagai.model.bert_model", "BertModel", "bert"], - "t5-base-en": ["flagai.model.t5_model", "T5Model", "t5"], - "t5-base-ch": ["flagai.model.t5_model", "T5Model", "t5"], - "glm-large-ch": ["flagai.model.glm_model", "GLMModel", "glm"], - "glm-large-en": ["flagai.model.glm_model", "GLMModel", "glm"], - "gpt2-base-ch": ["flagai.model.gpt2_model", "GPT2Model", "gpt2"], - "cpm-large-ch": ["flagai.model.gpt2_model", "GPT2Model", "cpm"], - "opt-125m-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-350m-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-1.3b-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-2.7b-en": 
["flagai.model.opt_model","OPTModel", "opt"], - "opt-6.7b-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-13b-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-30b-en": ["flagai.model.opt_model","OPTModel", "opt"], - "opt-66b-en": ["flagai.model.opt_model","OPTModel", "opt"], - "glm-10b-ch": ["flagai.model.glm_model", "GLMModel", "glm"], + "bert-base-en": ["flagai.model.bert_model", "BertModel", "bert", "nlp"], + "roberta-base-ch": ["flagai.model.bert_model", "BertModel", "bert", "nlp"], + "t5-base-en": ["flagai.model.t5_model", "T5Model", "t5", "nlp"], + "t5-base-ch": ["flagai.model.t5_model", "T5Model", "t5", "nlp"], + "glm-large-ch": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"], + "glm-large-en": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"], + "gpt2-base-ch": ["flagai.model.gpt2_model", "GPT2Model", "gpt2", "nlp"], + "cpm-large-ch": ["flagai.model.gpt2_model", "GPT2Model", "cpm", "nlp"], + "opt-125m-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-350m-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-1.3b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-2.7b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-6.7b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-13b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-30b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "opt-66b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"], + "glm-10b-ch": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"], + + "vit-base-p16-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-base-p16-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-base-p32-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-base-p32-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-large-p16-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-large-p16-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-large-p32-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], + "vit-large-p32-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"], } TOKENIZER_DICT = { @@ -103,10 +114,8 @@ def __getattr__(self, name): "opt-13b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], "opt-30b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], "opt-66b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"], - } - class AutoLoader: def __init__(self, @@ -153,6 +162,8 @@ def __init__(self, return brief_model_name = MODEL_DICT[model_name][2] + model_type = MODEL_DICT[model_name][3] + # The dir to save config, vocab and model. 
self.model_name = ALL_TASK.get(f"{brief_model_name}_{task_name}", None) @@ -184,38 +195,41 @@ def __init__(self, model_id = _get_model_id(model_name) print("*"*20, task_name, model_id, model_name) - if "glm" in model_name and "ch" in model_name: - vocab_file = os.path.join(download_path,'cog-pretrained.model') - if not os.path.exists(vocab_file): - vocab_file = _get_vocab_path(download_path, "cog-pretrain.model", model_id) - elif "glm" in model_name and "en" in model_name: - vocab_file = "GLM-large-en" - elif model_name == "cpm-large-ch": - # two files to load - vocab_file_1 = os.path.join(download_path, "vocab.json") - vocab_file_2 = os.path.join(download_path, "chinese_vocab.model") - if not os.path.exists(vocab_file_1): - vocab_file_1 = _get_vocab_path(download_path, "vocab.json", - model_id) - if not os.path.exists(vocab_file_2): - vocab_file_2 = _get_vocab_path(download_path, - "chinese_vocab.model", model_id) - else: - vocab_file = os.path.join(download_path, 'vocab.txt') - if not os.path.exists(vocab_file): - vocab_file = _get_vocab_path(download_path, "vocab.txt", - model_id) - tokenizer_class = TOKENIZER_DICT[model_name] - tokenizer_class = getattr(LazyImport(tokenizer_class[0]), - tokenizer_class[1]) - if model_name == "cpm-large-ch": - self.tokenizer = tokenizer_class(vocab_file_1, vocab_file_2) - elif brief_model_name == "opt": - self.tokenizer = tokenizer_class("facebook/opt-350m") - elif model_name in ["glm-large-en", "glm-large-ch"]: - self.tokenizer = tokenizer_class() - else : - self.tokenizer = tokenizer_class(vocab_file) + if model_type == "nlp": + if "glm" in model_name and "ch" in model_name: + vocab_file = os.path.join(download_path,'cog-pretrained.model') + if not os.path.exists(vocab_file): + vocab_file = _get_vocab_path(download_path, "cog-pretrain.model", model_id) + elif "glm" in model_name and "en" in model_name: + vocab_file = "GLM-large-en" + elif model_name == "cpm-large-ch": + # two files to load + vocab_file_1 = os.path.join(download_path, "vocab.json") + vocab_file_2 = os.path.join(download_path, "chinese_vocab.model") + if not os.path.exists(vocab_file_1): + vocab_file_1 = _get_vocab_path(download_path, "vocab.json", + model_id) + if not os.path.exists(vocab_file_2): + vocab_file_2 = _get_vocab_path(download_path, + "chinese_vocab.model", model_id) + else: + vocab_file = os.path.join(download_path, 'vocab.txt') + if not os.path.exists(vocab_file): + vocab_file = _get_vocab_path(download_path, "vocab.txt", + model_id) + tokenizer_class = TOKENIZER_DICT[model_name] + tokenizer_class = getattr(LazyImport(tokenizer_class[0]), + tokenizer_class[1]) + if model_name == "cpm-large-ch": + self.tokenizer = tokenizer_class(vocab_file_1, vocab_file_2) + elif brief_model_name == "opt": + self.tokenizer = tokenizer_class("facebook/opt-350m") + elif model_name in ["glm-large-en", "glm-large-ch"]: + self.tokenizer = tokenizer_class() + else : + self.tokenizer = tokenizer_class(vocab_file) + elif model_type == "vision": + self.tokenizer = None def get_task_name(self, brief_model_name): all_model_task = list(ALL_TASK.keys()) diff --git a/flagai/env_args.py b/flagai/env_args.py new file mode 100644 index 00000000..49c5ce29 --- /dev/null +++ b/flagai/env_args.py @@ -0,0 +1,110 @@ +import argparse + +def save_best(best_score, eval_dict): + return best_score if best_score < eval_dict['loss'] else eval_dict['loss'] + +def str2bool(v): + if isinstance(v,bool): + return v + if v == 'True': + return True + if v == 'False': + return False + +class EnvArgs: + def __init__(self, + 
env_type="pytorch", + experiment_name="test_experiment", + epochs=1, + batch_size=1, + lr=1e-5, + seed=1234, + + fp16=False, + pytorch_device="cpu", + clip_grad=1.0, + checkpoint_activations=False, + gradient_accumulation_steps=1, + weight_decay=1e-5, + warm_up=0.1, + + log_interval=100, + eval_interval=1000, + save_interval=1000, + + save_dir=None, + load_dir=None, + save_optim=False, # save current optimizer.') + save_rng=False, # save current rng state.') + load_type='latest', # latest, best + load_optim=False, # not load optimizer when loading checkpoint.') + load_rng=False, + tensorboard_dir="tensorboard_summary", + + # distribute settings + deepspeed_activation_checkpointing=False, + num_checkpoints=1, + master_ip='localhost', + master_port=17750, + num_nodes=1, + num_gpus=1, + hostfile="./hostfile", + deepspeed_config="./deepspeed.json", + model_parallel_size=1, + training_script="train.py", + ): + + self.parser = argparse.ArgumentParser(description='Env args parser') + self.parser.add_argument('--env_type', default=env_type, help='the model will be trained') + self.parser.add_argument('--experiment_name', default=experiment_name, help='start training from saved checkpoint') + self.parser.add_argument('--epochs', default=epochs, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--batch_size', default=batch_size, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--lr', default=lr, type=float, help='start training from saved checkpoint') + self.parser.add_argument('--seed', default=seed, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--fp16', default=fp16, type=str2bool, help='start training from saved checkpoint') + self.parser.add_argument('--pytorch_device', default=pytorch_device, help='start training from saved checkpoint') + self.parser.add_argument('--clip_grad', default=clip_grad, type=float, help='start training from saved checkpoint') + self.parser.add_argument('--checkpoint_activations', default=checkpoint_activations, type=str2bool, help='start training from saved checkpoint') + self.parser.add_argument('--gradient_accumulation_steps', default=gradient_accumulation_steps, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--weight_decay', default=weight_decay, type=float, help='start training from saved checkpoint') + self.parser.add_argument('--warm_up', default=warm_up, type=float, help='start training from saved checkpoint') + self.parser.add_argument('--log_interval', default=log_interval, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--eval_interval', default=eval_interval, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--save_interval', default=save_interval, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--save_dir', default=save_dir, help='start training from saved checkpoint') + self.parser.add_argument('--load_dir', default=load_dir, help='start training from saved checkpoint') + self.parser.add_argument('--save_optim', default=save_optim, type=str2bool, help='start training from saved checkpoint') + self.parser.add_argument('--save_rng', default=save_rng, type=str2bool,help='start training from saved checkpoint') + self.parser.add_argument('--load_type', default=load_type, type=str2bool,help='start training from saved checkpoint') + self.parser.add_argument('--load_optim', default=load_optim, type=str2bool,help='start 
training from saved checkpoint') + self.parser.add_argument('--load_rng', default=load_rng, type=str2bool, help='start training from saved checkpoint') + self.parser.add_argument('--tensorboard_dir', default=tensorboard_dir, help='start training from saved checkpoint') + self.parser.add_argument('--deepspeed_activation_checkpointing', default=deepspeed_activation_checkpointing, help='start training from saved checkpoint') + self.parser.add_argument('--num_checkpoints', default=num_checkpoints, help='start training from saved checkpoint') + self.parser.add_argument('--deepspeed_config', default=deepspeed_config, help='start training from saved checkpoint') + self.parser.add_argument('--model_parallel_size', default=model_parallel_size, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--training_script', default=training_script, help='start training from saved checkpoint') + + self.parser.add_argument('--hostfile', default=hostfile, help='start training from saved checkpoint') + self.parser.add_argument('--master_ip', default=master_ip, help='start training from saved checkpoint') + self.parser.add_argument('--master_port', default=master_port, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--num_nodes', default=num_nodes, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--num_gpus', default=num_gpus, type=int, help='start training from saved checkpoint') + self.parser.add_argument('--not_call_launch', action="store_true", help='start training from saved checkpoint') + self.parser.add_argument('--local_rank', default=0, type=int, help='start training from saved checkpoint') + + def add_arg(self, arg_name, default=None, type=str, help="", store_true=False): + if store_true: + self.parser.add_argument(f"--{arg_name}", default=default, type=type, action="store_true", help=help) + else : + self.parser.add_argument(f"--{arg_name}", default=default, type=type, help=help) + + + def parse_args(self): + args = self.parser.parse_args() + if args.env_type == "pytorch": + # not need the "not_call_launch" parameter + args.not_call_launch = True + + return args + diff --git a/flagai/env_trainer.py b/flagai/env_trainer.py new file mode 100644 index 00000000..c7ef8678 --- /dev/null +++ b/flagai/env_trainer.py @@ -0,0 +1,920 @@ +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +# Arguments for training +try: + import deepspeed.utils + import deepspeed +except: + pass +try: + from flagai import mpu +except Exception: + pass + +import torch +import argparse +import os +import random +import numpy as np +import torch.distributed as dist +from flagai.logger import log_dist +from torch.utils.tensorboard import SummaryWriter +from flagai.utils import load_checkpoint, save_checkpoint, load_optim, load_rng +from flagai.schedulers import AnnealingLR +from flagai.optimizers import get_optimizer, get_optimizer_param_groups +from flagai.fp16 import FP16_Module +from flagai.utils import Timers +from flagai.launch import launch_dist +from torch.nn.parallel import DistributedDataParallel as DDP +from flagai.fp16 import DynamicLossScaler +""" +The Trainer class, to easily train a pytorh model on a new task. 
+""" +def save_best(best_score, eval_dict): + return best_score if best_score < eval_dict['loss'] else eval_dict['loss'] + +def get_args_list(env_args): + not_need_to_launch_args = ["not_call_launch", "local_rank", "master_port", "master_ip", "hostfile", "num_gpus", "num_nodes"] + args_list = [] + args = dir(env_args) + for arg in args: + if not arg.startswith("__") and not arg.startswith("_") and arg not in not_need_to_launch_args: + args_list.append(f"--{arg}") + args_list.append(str(getattr(env_args, arg))) + + print(f"args list is {args_list}") + return args_list + +class EnvTrainer(): + def __init__(self, + env_args, + ): + self.timers = Timers() + self.env_type = env_args.env_type + if self.env_type not in set( + ["deepspeed", 'pytorch', 'pytorchDDP', 'deepspeed+mpu']): + raise Exception("Not supported env_type!!!!") + os.environ["ENV_TYPE"] = env_args.env_type + self.experiment_name = env_args.experiment_name + self.batch_size = env_args.batch_size + self.gradient_accumulation_steps = env_args.gradient_accumulation_steps + self.lr = env_args.lr + self.weight_decay = env_args.weight_decay + self.epochs = env_args.epochs + self.clip_grad = env_args.clip_grad + self.seed = env_args.seed + self.fp16 = env_args.fp16 + self.warm_up = env_args.warm_up + + self.log_interval = env_args.log_interval + self.eval_interval = env_args.eval_interval + + # model checkpointing + self.save_dir = env_args.save_dir + self.save_interval = env_args.save_interval + self.save_optim = env_args.save_optim + self.save_rng = env_args.save_rng + self.save_best = save_best + self.load_dir = env_args.load_dir + self.load_type = env_args.load_type + self.load_optim = env_args.load_optim + self.load_rng = env_args.load_rng + self.tb_writer = SummaryWriter( + os.path.join(env_args.tensorboard_dir, env_args.experiment_name)) + + # distribute settings + self.pytorch_device = env_args.pytorch_device + self.checkpoint_activations = env_args.checkpoint_activations + self.deepspeed_activation_checkpointing = env_args.deepspeed_activation_checkpointing + self.num_checkpoints = env_args.num_checkpoints + self.env_type = env_args.env_type + self.not_call_launch = env_args.not_call_launch + self.deepspeed_config = env_args.deepspeed_config + self.model_parallel_size = env_args.model_parallel_size + self.num_nodes = env_args.num_nodes + self.num_gpus = env_args.num_gpus + self.master_ip = env_args.master_ip + self.master_port = env_args.master_port + self.hostfile = env_args.hostfile + self.training_script = env_args.training_script + + if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP': + training_paras = get_args_list(env_args) + self.rank = int(os.environ.get('RANK', 0)) + self.world_size = int(os.environ.get('WORLD_SIZE', 1)) + self.local_rank = env_args.local_rank + log_dist("not_call_launch: {}".format(self.not_call_launch)) + # Implement for AutoLaunch + # >>> python train.py # will call get_dist_args() + # `--not_call_launch` is default 'False' + # So, if `env_type` is `pytorch`, the `Trainer` will not call lanch_dist() + # Otherwise, the lanch_dist() is called to launch 'train.py' with `--not_call_launch` + if not self.not_call_launch: + launch_dist(launcher='distributed_deepspeed' if 'deepspeed' + in self.env_type else 'distributed_torch', + num_nodes=self.num_nodes, + gpus_per_node=self.num_gpus, + master_addr=self.master_ip, + master_port=self.master_port, + hostfile=self.hostfile, + training_script=self.training_script, + training_paras=training_paras) + os._exit(1) + self.initialize_distributed() + 
+ def set_seed(self, seed=1234): + """Set random seed for reproducability.""" + if seed is not None and seed > 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if self.env_type == 'deepspeed+mpu': + mpu.model_parallel_cuda_manual_seed(seed) + + def initialize_distributed(self): + """Initialize torch.distributed.""" + if self.env_type == 'pytorch': + log_dist('No need to initialize') + return + if self.env_type in ['deepspeed', 'deepspeed+mpu', 'pytorchDDP']: + torch.backends.cudnn.enabled = False + # Manually set the device ids. + device = self.rank % torch.cuda.device_count() + if self.local_rank is not None: + device = self.local_rank + torch.cuda.set_device(device) + # Call the init process + init_method = 'tcp://' + self.master_ip = os.getenv('MASTER_ADDR', 'localhost') + self.master_port = os.getenv('MASTER_PORT', '6000') + + init_method += self.master_ip + ':' + self.master_port + log_dist( + "init method {}, rank {}, device {}, local_rank {}.".format( + init_method, self.rank, device, self.local_rank)) + torch.distributed.init_process_group( + backend='nccl', # gloo + world_size=self.world_size, + rank=self.rank, + init_method=init_method) + # Set the model-parallel / data-parallel communicators. + if self.env_type == 'deepspeed+mpu': + os.environ["MODEL_PARALLEL_SIZE"] = str(self.model_parallel_size) + try: + mpu.initialize_model_parallel(self.model_parallel_size) + if 'deepspeed' in self.env_type and self.deepspeed_activation_checkpointing: + deepspeed.checkpointing.configure( + mpu, + deepspeed_config=self.deepspeed_config, + num_checkpoints=self.num_checkpoints) + mpu.checkpoint = deepspeed.checkpointing.checkpoint + mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed + except Exception as e: + log_dist(e) + log_dist("No mpu is installed! No model parallel is used") + log_dist("initialize eviroments succesed") + self.set_seed(self.seed) + + def get_dataloader(self, dataset, collate_fn, shuffle=False): + """ initilize the dataloader""" + if dataset is None: + return None + if self.env_type == 'pytorch': + return torch.utils.data.DataLoader(dataset, + batch_size=self.batch_size, + collate_fn=collate_fn, + num_workers=4, + prefetch_factor=4, + pin_memory=True, + drop_last=False, + shuffle=shuffle) + else: + if self.env_type == 'deepspeed+mpu': + # num_replicas = self.world_size // mpu.get_model_parallel_world_size( + # ) + # rank = self.rank // mpu.get_model_parallel_world_size() + # rank = mpu.get_model_parallel_rank() + rank = mpu.get_model_parallel_src_rank() + print("*"*80) + print("local rank",self.rank, "model rank", rank) + print("*"*80) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + # num_replicas=num_replicas, + rank=rank, + shuffle=shuffle) + else: + num_replicas = self.world_size + rank = self.rank + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, rank=rank, shuffle=shuffle) + return torch.utils.data.DataLoader(dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=4, + drop_last=False, + pin_memory=False, + prefetch_factor=4, + collate_fn=collate_fn) + + def train(self, + model=None, + optimizer=None, + lr_scheduler=None, + train_dataset=None, + valid_dataset=None, + metric_methods=[], + collate_fn=None): + """Training Loops""" + """ + Trainer is a simple but unifed training and eval loop for PyTorch/Deepspeed/Megatron-LM. 
+ Args: + model (`torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. + args ([`env_type`]): + The enviroment type for training. Will default to 'pytorch'. + env_type: `pytorch`, `pytorchDDP`, `deepspeed`, `deepspeed+mpu` + pytorch: single node cpu/gpu + pytorchDDP: single-/multi- node gpu + deepspeed: single-/multi- node gpu + deepspeed+mpu: single-/multi- node gpu + train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.DataLoader`, *optional*): + The dataset to use for training. + If it is an `Dataset`, we will create a `DataLoader` with the provided `Dataset` and `collate_fn' for the selected `env_type`. + `Dataset` is prefred to iterally return a sample as followings, + >>> {'text': 'I like big model.', 'label': 'positive'} + If it is an `DataLoader`, we will directly use it. + Important: Columns not accepted by the `model.forward()` method are automatically droped. + eval_dataset (`torch.utils.data.Dataset` or `torch.utils.data.DataLoader`, *optional*): + The dataset to use for evaluation. Similar to `train_dataset`. + collate_fn (`DataCollator` or `function`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. + metrics (`function`, *optional*): + The function that will be used to compute metrics at evaluation. Must return + a dictionary string to metric values. + optimizers (`torch.optim.Optimizer`, *optional*): A optimizer to use. Will default to an instance of + [`AdamW`] on your model. + lr_scheduler (`torch.optim.lr_scheduler`, *optional*): A lr_scheduler to use. Will default to an instance of + [`AnnealingLR`]. + """ + if not isinstance(train_dataset, torch.utils.data.DataLoader): + train_dataloader = self.get_dataloader(train_dataset, collate_fn, + True) + else: + train_dataloader = train_dataset + + if not isinstance(valid_dataset, torch.utils.data.DataLoader): + + valid_dataloader = self.get_dataloader(valid_dataset, collate_fn, + False) + else: + valid_dataloader = valid_dataset + + if self.load_dir: + log_dist("loading checkpoints form {}".format(self.load_dir)) + sd = load_checkpoint(model, + load_dir=self.load_dir, + load_type=self.load_type) + """Train the model.""" + # Turn on training mode which enables dropout. + model.train() + if self.fp16 and self.env_type == 'pytorchDDP': + log_dist( + "Warning: The pytorchDDP plus FP16 may not working togather!!!" 
+ ) + if self.fp16: + model.half() + if self.checkpoint_activations: + model.config[ + 'checkpoint_activations'] = self.checkpoint_activations + + if self.env_type == 'pytorchDDP': + model.to(torch.device('cuda', self.local_rank)) + model = DDP(model, + device_ids=[self.local_rank], + find_unused_parameters=True) + + elif self.env_type == 'pytorch': + model.to(self.pytorch_device) + else: + model.cuda(torch.device('cuda', self.local_rank)) + if self.fp16: + model = FP16_Module(model) + + param_groups = get_optimizer_param_groups(model) + + if hasattr(param_groups[0], 'params'): + # for T5 Model + param_groups = param_groups[0]['params'] + + if optimizer is None and 'deepspeed' not in self.env_type and self.epochs > 0: + optimizer = get_optimizer( + param_groups=param_groups, + lr=self.lr, + weight_decay=self.weight_decay, + cpu_optimizer=False, + cpu_torch_adam=False, + fp16=self.fp16, + optimizer='adam') # if not self.fp16 else 'adafactor') + + if lr_scheduler == None and optimizer != None and self.warm_up > 0 and 'deepspeed' not in self.env_type and self.epochs > 0: + + lr_scheduler = AnnealingLR( + optimizer, + start_lr=self.lr, + warmup_iter=int(self.warm_up * self.epochs * + len(train_dataloader)), + decay_style='linear', + num_iters=self.epochs * len(train_dataloader)) + + if 'deepspeed' in self.env_type: + # initialize the deepspeed + model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model, + # if huggingface t5: param_groups[0]['params'] + model_parameters=param_groups, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + mpu=mpu if self.env_type == 'deepspeed+mpu' else None, + config=self.deepspeed_config, + dist_init_required=True) + if self.load_optim: + print(self.load_optim) + print(type(self.load_optim)) + load_optim(optimizer, lr_scheduler, sd) + if self.load_rng: + load_rng(sd) + # Tracking loss. + total_lm_loss = 0.0 + self.iteration = 0 + self.accumulate_count = 0 + best_iteration = 0 + best_loss = float('inf') + # For each remaining epoch + self.timers('interval time').start() + # self.eval_metrics = eval_metrics + # self.do_eval = valid_dataset!=None + self.metric_methods = metric_methods + best_score = float('inf') + if len(self.metric_methods) > 0: + best_score = -best_score + + for epoch in range(self.epochs): + # log_dist('working on epoch {} ...'.format(epoch), [0]) + # Set the data loader epoch to shuffle the index iterator. + # if self.env_type == 'deepspeed+mpu': + # if mpu.get_model_parallel_rank() == 0: + # train_dataloader.sampler.set_epoch(epoch + self.world_size) + if self.env_type != 'pytorch': + train_dataloader.sampler.set_epoch(epoch + self.world_size) + + # For all the batches in the dataset. + for iteration_, batch in enumerate(train_dataloader): + # Train for one step. 
+ if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP': + batch = { + x: batch[x].to(torch.device('cuda', self.local_rank)) + for x in batch if x not in ['uid', 'meta', 'mode'] + } + elif 'pytorch' == self.env_type: + batch = { + x: batch[x].to(torch.device(self.pytorch_device)) + for x in batch if x not in ['uid', 'meta', 'mode'] + } + if self.env_type == 'pytorchDDP': + lm_loss, _ = self.train_step_pytorchDDP( + batch, model, optimizer, lr_scheduler) + dist.barrier() + + elif self.env_type == 'pytorch': + lm_loss, _ = self.train_step_pytorch( + batch, model, optimizer, lr_scheduler) + else: + lm_loss, _ = self.train_step_deepspeed(batch, + model, + optimizer, + lr_scheduler, + single_step=True) + dist.barrier() + if lm_loss is not None: + total_lm_loss += lm_loss.data.detach().float() + + # Logging. + if (self.iteration + 1) % self.log_interval == 0: + if optimizer is not None: + learning_rate = optimizer.param_groups[0]['lr'] + else: + learning_rate = model.optimizer.param_groups[0]['lr'] + avg_lm_loss = total_lm_loss.item() / self.log_interval + elapsed_time = self.timers('interval time').elapsed() + self.report_iteration_metrics( + optimizer, learning_rate, avg_lm_loss, + elapsed_time * 1000.0 / self.log_interval, + self.iteration + 1, + self.epochs * len(train_dataloader)) + self.tb_writer.add_scalar('train/loss', avg_lm_loss, + self.iteration + 1) + self.tb_writer.add_scalar('lr', learning_rate, + self.iteration + 1) + total_lm_loss = 0.0 + # Evaluation #todo add train_args + if self.eval_interval and ( + self.iteration + 1 + ) % self.eval_interval == 0 and valid_dataloader is not None: + self.timers.log(['forward', 'backward', 'optimizer'], + normalizer=self.eval_interval) + prefix = 'epoch {}'.format(epoch) + eval_dict = self.evaluate_and_print_results( + prefix=prefix, + data_loader=valid_dataloader, + model=model, + forward_step_func=self.forward_step, + verbose=False) + if eval_dict is not None: + eval_loss = eval_dict.get("loss", 0.0) + self.tb_writer.add_scalar('eval/loss', eval_loss, + self.iteration + 1) + for i in range(len(self.metric_methods)): + name = self.metric_methods[i][0] + score = eval_dict.get(name, 0) + self.tb_writer.add_scalar( + 'eval_metrics/%s' % (name), score, + self.iteration + 1) + + if self.save_best is not None and self.save_best(best_score, eval_dict) != best_score: + best_score = self.save_best(best_score, eval_dict) + log_dist("saving best model with score {:.4f}".format(best_score)) + best_iteration = self.iteration + save_checkpoint(self.iteration+1, + best_iteration+1, + + model, + optimizer, + lr_scheduler, + save_optim=self.save_optim, + save_dir=self.save_dir, + save_rng=self.save_rng) + if self.save_dir and (self.iteration + 1) % self.save_interval == 0 and \ + self.iteration != best_iteration: + save_checkpoint(self.iteration+1, + best_iteration+1, + model, + optimizer, + lr_scheduler, + save_optim=self.save_optim, + save_dir=self.save_dir, + save_rng=self.save_rng) + self.iteration += 1 + + # Checkpointing at the end of each epoch. 
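+            # (No additional per-epoch checkpoint is written here; saving happens in the
+            # iteration loop above via save_interval and save_best.)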
+ + # Evaluation #todo add train_args + if ((self.epochs == 0) or (self.eval_interval and + (self.iteration ) % self.eval_interval != 0) + ) and valid_dataloader is not None: + prefix = 'final evaluate' + self.evaluate_and_print_results( + prefix=prefix, + data_loader=valid_dataloader, + model=model, + forward_step_func=self.forward_step, + verbose=False) + + def train_step_pytorch(self, + data, + model, + optimizer, + lr_scheduler, + mems=None): + """Single training step.""" + # Forward model for one step. + self.timers('forward').start() + step_output = self.forward_step(data, model, mems) + self.timers('forward').stop() + # accumulate gradients + lm_loss = step_output['loss'] + lm_loss /= self.gradient_accumulation_steps + reduced_loss = lm_loss.detach().clone().view(1) + # skip the iter while loss has NAN + if not DynamicLossScaler._has_inf_or_nan(reduced_loss): + # Calculate gradients, reduce across processes, and clip. + self.timers('backward').start() + if self.fp16 and hasattr(optimizer, 'backward'): + optimizer.backward(lm_loss, + update_master_grads=False, + retain_graph=True) + else: + lm_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad) + self.timers('backward').stop() + + # Update parameters. + self.timers('optimizer').start() + if (self.accumulate_count + + 1) % self.gradient_accumulation_steps == 0: + if self.fp16: + # optimizer.update_master_grads() + optimizer.step() + optimizer.zero_grad() + else: + optimizer.step() + # optimizer.zero_grad() + self.accumulate_count = 0 + else: + self.accumulate_count += 1 + if lr_scheduler: + lr_scheduler.step() + self.timers('optimizer').stop() + + else: + log_dist("Found NaN loss, skip backward", [0]) + del lm_loss, reduced_loss + mems = None + reduced_loss = None + return reduced_loss, mems + + def train_step_pytorchDDP(self, + data, + model, + optimizer, + lr_scheduler, + mems=None): + """Single training step.""" + + from contextlib import nullcontext + if self.fp16: + no_sync = model.module.no_sync + else: + no_sync = model.no_sync + + mycontext = no_sync if ( + self.accumulate_count + + 1) != self.gradient_accumulation_steps else nullcontext + + with mycontext(): + # Forward model for one step. + self.timers('forward').start() + step_output = self.forward_step(data, model, mems) + self.timers('forward').stop() + + # accumulate gradients + lm_loss = step_output['loss'] + lm_loss /= self.gradient_accumulation_steps + # reduce sum of losses + reduced_loss = lm_loss.detach().clone().view(1) + # dist.all_reduce(reduced_loss.data) + # reduced_loss.data = reduced_loss.data / self.world_size + + # skip the iter while loss has NAN + if not DynamicLossScaler._has_inf_or_nan(reduced_loss): + # Calculate gradients, reduce across processes, and clip. + self.timers('backward').start() + + if self.fp16 and hasattr(optimizer, 'backward'): + log_dist("The optimizer has backward function") + optimizer.backward(lm_loss, + update_master_grads=False, + retain_graph=True) + else: + lm_loss.backward() + + torch.nn.utils.clip_grad_norm_(model.module.parameters(), + self.clip_grad) + self.timers('backward').stop() + + # Update parameters. 
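+                # Step the optimizer only on a gradient-accumulation boundary; in fp16
+                # mode the fp32 master grads are updated before the step.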
+ self.timers('optimizer').start() + if (self.accumulate_count + + 1) % self.gradient_accumulation_steps == 0: + if self.fp16: + optimizer.update_master_grads() + optimizer.step() + optimizer.zero_grad() + else: + optimizer.step() + # model.zero_grad() + + self.accumulate_count = 0 + else: + self.accumulate_count += 1 + if lr_scheduler: + lr_scheduler.step() + self.timers('optimizer').stop() + dist.barrier() + + else: + log_dist("Found NaN loss, skip backward", [0]) + del lm_loss, reduced_loss + mems = None + reduced_loss = None + return reduced_loss, mems + + def train_step_deepspeed(self, + data, + model, + optimizer, + lr_scheduler, + mems=None, + single_step=False): + """Single training step.""" + + # Forward model for one step. + if (self.accumulate_count + 1) % self.gradient_accumulation_steps == 0: + model.set_gradient_accumulation_boundary(True) + else: + model.set_gradient_accumulation_boundary(False) + self.timers('forward').start() + step_output = self.forward_step(data, model, mems) + self.timers('forward').stop() + lm_loss = step_output['loss'] + reduced_loss = lm_loss.detach().clone().view(1) + + if self.env_type == 'deepspeed+mpu': + torch.distributed.all_reduce(reduced_loss.data, + group=mpu.get_data_parallel_group()) + elif self.env_type == 'deepspeed': + torch.distributed.all_reduce(reduced_loss.data) + if 'deepspeed' in self.env_type: + reduced_loss.data = reduced_loss.data / \ + (self.world_size / self.model_parallel_size) + if not DynamicLossScaler._has_inf_or_nan(reduced_loss): + # Calculate gradients, reduce across processes, and clip. + self.timers('backward').start() + model.backward(lm_loss) + self.timers('backward').stop() + # Update parameters. + self.timers('optimizer').start() + model.step() + if lr_scheduler: + lr_scheduler.step() + self.timers('optimizer').stop() + if (self.accumulate_count + + 1) % self.gradient_accumulation_steps == 0: + self.accumulate_count = 0 + else: + self.accumulate_count += 1 + dist.barrier() + else: + log_dist("Found NaN loss, skip backward", [0]) + del lm_loss, reduced_loss + mems = [] + reduced_loss = None + return reduced_loss, mems + + def forward_step(self, data, model, mems=None): + """Simple forward step. """ + data['mems'] = mems + model_output = model(**data) + logits = model_output['logits'] + loss = model_output['loss'] + hidden_states = None + if 'hidden_states' in model_output: + hidden_states = model_output['hidden_states'] + elif 'encoder_hidden_states' in model_output: + hidden_states = model_output['encoder_hidden_states'] + + return { + 'loss': loss, + 'hidden_states': hidden_states, + 'logits': logits.contiguous().float() + } + + def backward_step(self, optimizer, model, lm_loss): + """Backward step.""" + + # Total loss. + loss = lm_loss + # Backward pass. + # if self.train_args.deepspeed: + if 'deepspeed' in self.env_type: + model.backward(loss) + else: + # optimizer.zero_grad() + if hasattr(optimizer, 'backward'): + optimizer.backward(loss, update_master_grads=False) + else: + loss.backward() + if self.env_type == 'pytorchDDP': + optimizer.step() + + # if self.train_args.deepspeed or self.train_args.DDP_impl == 'torch': + self.timers('allreduce').reset() + if self.env_type == 'pytorch': + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad) + return lm_loss + + def _gather_all(self, input_): + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size() == 1: + return input_ + # Size and dimension. 
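+        # Gather this rank's tensor from every process and concatenate along dim 0
+        # so evaluation metrics see the full validation set.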
+ last_dim = input_.dim() - 1 + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + tensor_list = [ + torch.empty_like(input_, device=input_.device) + for _ in range(world_size) + ] + tensor_list[rank] = input_ + + torch.distributed.all_gather(tensor_list, input_) + + # Note: torch.cat already creates a contiguous tensor. + if last_dim >= 0: + output = torch.cat(tensor_list, dim=0).contiguous() + else: + output = torch.mean(torch.FloatTensor(tensor_list)) + + return output + + def _gather_all_mpu(self, input_): + group = mpu.get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + # Size and dimension. + last_dim = input_.dim() - 1 + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + + tensor_list = [ + torch.empty_like(input_, device=input_.device) + for _ in range(world_size) + ] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + def evaluate(self, + data_loader=None, + model=None, + forward_step_func=None, + verbose=False): + """Evaluation.""" + + # Turn off checkpoint_activations + tmp_checkpoint_activations = None + tmp_model = model + while hasattr(tmp_model, 'module'): + tmp_model = tmp_model.module + # Turn on evaluation mode which disables dropout. + tmp_model.eval() + if hasattr(tmp_model, + 'config') and 'checkpoint_activations' in tmp_model.config: + tmp_checkpoint_activations = tmp_model.config[ + 'checkpoint_activations'] + tmp_model.config['checkpoint_activations'] = False + + mems = None + metrics = [0. for _ in range(len(self.metric_methods))] + + with torch.no_grad(): + assert data_loader is not None, "val loader is not None." + all_logits = [] + all_labels = [] + all_losses = [] + for data_iterator in data_loader: + # Forward evaluation. 
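+                # Move the batch to the target device (dropping 'uid'/'meta'/'mode'),
+                # run a forward pass, and collect logits, labels and losses.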
+ + meta = data_iterator.get('meta', None) + + if 'deepspeed' in self.env_type or 'DDP' in self.env_type: + data_iterator = { + x: data_iterator[x].to( + torch.device('cuda', self.local_rank)) + for x in data_iterator + if x not in ['uid', 'meta', 'mode'] + } + elif torch.cuda.is_available(): + + data_iterator = { + x: + data_iterator[x].to(torch.device(self.pytorch_device)) + for x in data_iterator + if x not in ['uid', 'meta', 'mode'] + } + step_output = forward_step_func(data_iterator, model, mems) + '''when contiguous memory optimizations are enabled, the buffers + allocated by the optimizations are deallocated during backward pass + in the absence of backward pass the buffers should be reset after each + forward pass''' + if 'deepspeed' in self.env_type and self.deepspeed_activation_checkpointing: + deepspeed.checkpointing.reset() + logits = step_output['logits'] + lm_loss = step_output['loss'] + + if 'labels' in data_iterator: + labels = data_iterator['labels'] + else: + labels = data_iterator['target_ids'] + + all_logits.append(logits) + all_labels.append(labels) + all_losses.append(lm_loss.view(1)) + + if len(self.metric_methods) != 0: + all_logits = torch.cat(all_logits, dim=0) + all_labels = torch.cat(all_labels, dim=0) + + all_losses = torch.cat(all_losses, dim=0) + + if self.env_type == 'pytorchDDP' or self.env_type == 'deepspeed': + if len(self.metric_methods) != 0: + all_logits = self._gather_all(all_logits) + all_labels = self._gather_all(all_labels) + all_losses = self._gather_all(all_losses) + + elif self.env_type == 'deepspeed+mpu': + if len(self.metric_methods) != 0: + all_logits = self._gather_all_mpu(all_logits) + all_labels = self._gather_all_mpu(all_labels) + all_losses = self._gather_all_mpu(all_losses) + + if all_losses.device != torch.device('cpu'): + all_losses = all_losses.cpu().detach().numpy()[0] + + for i in range(len(self.metric_methods)): + eval_method = self.metric_methods[i][1] + metrics[i] += eval_method(all_logits, all_labels, meta=meta) + + # Move model back to the train mode. 
+ + # model.train() + tmp_model.train() + # recover the settings for checkpoint_activations + if hasattr(tmp_model, + 'config') and 'checkpoint_activations' in tmp_model.config: + tmp_model.config[ + 'checkpoint_activations'] = tmp_checkpoint_activations + metric_dct = {} + for i in range(len(self.metric_methods)): + metric_name = self.metric_methods[i][0] + metric_dct.update({metric_name: metrics[i]}) + metric_dct.update({"loss": all_losses}) + return metric_dct + + def report_iteration_metrics(self, optimizer, lr, loss, elapsed_time, step, + total_step): + log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time) + log_string += ' learning rate {:.3E} |'.format(lr) + log_string += ' loss {:.6E} |'.format(loss) + if self.fp16: + log_string += ' loss scale {:.1f} |'.format( + optimizer.cur_scale if 'deepspeed' in self.env_type else + hasattr(optimizer, 'loss_scale') and optimizer.loss_scale) + # log_string += ' gradient_accumulation {}/{}'.format(self.accumulate_count, self.gradient_accumulation_steps) + log_dist(log_string, [0]) + + def report_evaluate_metrics(self, prefix, loss, ppl, gpt_loss, bert_loss, + sent_loss, multi_loss, step): + string = ' validation loss at {}'.format(prefix) + string += ' | LM loss: {:.6E}'.format(loss) + string += ' | LM PPL: {:.6E}'.format(ppl) + length = len(string) + 1 + log_dist('-' * 100, [0]) + log_dist('-' * length, [0]) + log_dist(string, [0]) + log_dist('-' * length, [0]) + + def evaluate_and_print_results( + self, + prefix=None, + forward_step_func=None, + data_loader=None, + model=None, + verbose=False, + ): + """Helper function to evaluate and dump results on screen.""" + eval_dict = self.evaluate(forward_step_func=forward_step_func, + data_loader=data_loader, + model=model, + verbose=verbose) + if eval_dict.get("loss", None) is not None: + string = ' validation loss at {} | {:.4f}, '.format( + prefix, eval_dict["loss"]) + # with open("results.txt", "a") as myfile: + # myfile.write(string) + if self.metric_methods is None: + return eval_dict + + for i in range(len(self.metric_methods)): + name = self.metric_methods[i][0] + string += ", {} {:.3f}".format(name, eval_dict[name]) + # string = ' validation loss at {} | {:.4f}, Acc {:.2f}'.format( + # prefix, eval_dict["loss"], eval_dict["metrics"]) + length = len(string) + 1 + log_dist('-' * length, [0]) + log_dist(string, [0]) + log_dist('-' * length, [0]) + return eval_dict \ No newline at end of file diff --git a/flagai/launch.py b/flagai/launch.py index 3dcfe22b..ecba3254 100644 --- a/flagai/launch.py +++ b/flagai/launch.py @@ -74,7 +74,8 @@ def launch_dist(launcher='distributed_deepspeed', hostfile='hostfile', nccl_info=False, training_script='train.py', - training_script_paras=None): + training_script_paras=None, + training_paras=None,): try: resource_pool = fetch_hostfile(hostfile) except: @@ -151,6 +152,9 @@ def launch_dist(launcher='distributed_deepspeed', ] cmd_launch.extend(torch_distributed_args) cmd_launch.append(training_script) + if training_paras: + cmd_launch.extend(training_paras) + cmd_launch.append('--not_call_launch') run_cmd = ' '.join(cmd_launch) log_dist(run_cmd) @@ -196,6 +200,9 @@ def launch_dist(launcher='distributed_deepspeed', if len(training_script_paras) > 0: cmd_launch.extend(training_script_paras) + if training_paras: + cmd_launch.extend(training_paras) + cmd_launch.append('--not_call_launch') run_cmd = ' '.join(cmd_launch) log_dist(run_cmd) @@ -226,6 +233,9 @@ def 
launch_dist(launcher='distributed_deepspeed', if len(training_script_paras) > 0: cmd_launch.extend(training_script_paras) + if training_paras: + cmd_launch.extend(training_paras) + run_cmd = ' '.join(cmd_launch) log_dist(run_cmd) subprocess.Popen(run_cmd, shell=True) diff --git a/flagai/model/base_model.py b/flagai/model/base_model.py index 5480b73b..c600cfb4 100644 --- a/flagai/model/base_model.py +++ b/flagai/model/base_model.py @@ -9,7 +9,6 @@ from flagai.model.file_utils import _get_model_id, _get_config_path, _get_checkpoint_path, _get_vocab_path, _get_model_files import os - # The base model for models class BaseModel(Module): @@ -59,12 +58,34 @@ def from_pretrain(cls, # downloading the files model: Union[Module, None] if model_id and model_id != "null": + model_files = eval(_get_model_files(model_name)) if not os.path.exists(os.path.join(download_path, 'vocab.txt')): - _get_vocab_path(download_path, "vocab.txt", model_id) + if "vocab.txt" in model_files: + _get_vocab_path(download_path, "vocab.txt", model_id) if not only_download_config and not os.path.exists(os.path.join(download_path, 'config.json')): - model_files = eval(_get_model_files(model_name)) - if 'pytorch_model.bin' in model_files: + if os.getenv('ENV_TYPE') == 'deepspeed+mpu': + model_parallel_size = int(os.getenv("MODEL_PARALLEL_SIZE")) + if model_parallel_size > 1: + # if gpus == nums_of_modelhub_models + # can load + # else need to download the pytorch_model.bin and to recut. + model_hub_parallel_size = 0 + for f in model_files: + if "pytorch_model_" in f: + model_hub_parallel_size += 1 + else: + model_parallel_size = 1 + + if "pytorch_model_01.bin" in model_files and model_parallel_size > 1 and model_hub_parallel_size == model_parallel_size: + # Only to download the model slices(megatron-lm). + for file_to_load in model_files: + if "pytorch_model_" in file_to_load: + _get_checkpoint_path(download_path, + file_to_load, + model_id) + + elif 'pytorch_model.bin' in model_files: checkpoint_path = _get_checkpoint_path(download_path, 'pytorch_model.bin', model_id) diff --git a/flagai/model/vision/helpers.py b/flagai/model/vision/helpers.py new file mode 100644 index 00000000..1e56190d --- /dev/null +++ b/flagai/model/vision/helpers.py @@ -0,0 +1,70 @@ + +import os +if os.getenv('ENV_TYPE') == 'deepspeed': + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint +else: + from torch.utils.checkpoint import checkpoint +import torch +from itertools import chain + +def checkpoint_seq( + functions, + x, + every=1, + flatten=False, + skip_last=False, +): + r"""A helper function for checkpointing sequential models. + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + .. warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. 
+ Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + Example: + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint(run_function(start, end, functions), x) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x + diff --git a/flagai/model/vision/layers/__init__.py b/flagai/model/vision/layers/__init__.py new file mode 100755 index 00000000..7e9e7b19 --- /dev/null +++ b/flagai/model/vision/layers/__init__.py @@ -0,0 +1,42 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel 
+from .separable_conv import SeparableConv2d, SeparableConvNormAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ diff --git a/flagai/model/vision/layers/activations.py b/flagai/model/vision/layers/activations.py new file mode 100755 index 00000000..e16b3bd3 --- /dev/null +++ b/flagai/model/vision/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
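+
+# Note: with inplace=False this matches F.hardsigmoid in recent PyTorch releases;
+# the function is kept so every activation here exposes the same `inplace` argument.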
+ + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/flagai/model/vision/layers/activations_jit.py b/flagai/model/vision/layers/activations_jit.py new file mode 100755 index 00000000..b4a51653 --- /dev/null +++ b/flagai/model/vision/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
+ + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/flagai/model/vision/layers/activations_me.py b/flagai/model/vision/layers/activations_me.py new file mode 100755 index 00000000..9a12bb7e --- /dev/null +++ b/flagai/model/vision/layers/activations_me.py @@ -0,0 +1,218 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + @staticmethod + def symbolic(g, x): + return g.op("Mul", x, g.op("Sigmoid", x)) + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) 
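+
+# The matching backward below implements d/dx hard_sigmoid(x) = 1/6 for -3 <= x <= 3,
+# 0 elsewhere.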
+ + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + @staticmethod + def symbolic(g, self): + input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) + hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + return g.op("Mul", self, hardtanh_) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) 
& (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/flagai/model/vision/layers/adaptive_avgmax_pool.py b/flagai/model/vision/layers/adaptive_avgmax_pool.py new file mode 100755 index 00000000..ebc6ada8 --- /dev/null +++ b/flagai/model/vision/layers/adaptive_avgmax_pool.py @@ -0,0 +1,118 @@ +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adaptive_pool_feat_mult(pool_type='avg'): + if pool_type == 'catavgmax': + return 2 + else: + return 1 + + +def adaptive_avgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool2d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool2d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool2d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool2d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool2d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool2d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool2d(x, self.output_size) + + +class SelectAdaptivePool2d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool2d, self).__init__() + 
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool2d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + diff --git a/flagai/model/vision/layers/attention_pool2d.py b/flagai/model/vision/layers/attention_pool2d.py new file mode 100755 index 00000000..a13a6881 --- /dev/null +++ b/flagai/model/vision/layers/attention_pool2d.py @@ -0,0 +1,131 @@ +""" Attention Pool 2D + +Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. + +Based on idea in CLIP by OpenAI, licensed Apache 2.0 +https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Union, Tuple + +import torch +import torch.nn as nn + +from .helpers import to_2tuple +from .pos_embed import apply_rot_embed, RotaryEmbedding +from .weight_init import trunc_normal_ + + +class RotAttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from + train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
-RW + """ + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + assert embed_dim % num_heads == 0 + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pos_embed = RotaryEmbedding(self.head_dim) + + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + x = x.reshape(B, -1, N).permute(0, 2, 1) + + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + + qc, q = q[:, :, :1], q[:, :, 1:] + sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) + q = apply_rot_embed(q, sin_emb, cos_emb) + q = torch.cat([qc, q], dim=2) + + kc, k = k[:, :, :1], k[:, :, 1:] + k = apply_rot_embed(k, sin_emb, cos_emb) + k = torch.cat([kc, k], dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +class AttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + It was based on impl in CLIP by OpenAI + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. 
+ """ + def __init__( + self, + in_features: int, + feat_size: Union[int, Tuple[int, int]], + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.feat_size = to_2tuple(feat_size) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + + spatial_dim = self.feat_size[0] * self.feat_size[1] + self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + trunc_normal_(self.pos_embed, std=in_features ** -0.5) + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + assert self.feat_size[0] == H + assert self.feat_size[1] == W + x = x.reshape(B, -1, N).permute(0, 2, 1) + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] diff --git a/flagai/model/vision/layers/blur_pool.py b/flagai/model/vision/layers/blur_pool.py new file mode 100755 index 00000000..e73d8863 --- /dev/null +++ b/flagai/model/vision/layers/blur_pool.py @@ -0,0 +1,42 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. 
+ """ + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) diff --git a/flagai/model/vision/layers/bottleneck_attn.py b/flagai/model/vision/layers/bottleneck_attn.py new file mode 100755 index 00000000..c3db464e --- /dev/null +++ b/flagai/model/vision/layers/bottleneck_attn.py @@ -0,0 +1,157 @@ +""" Bottleneck Self Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple, make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) + + def forward(self, q): + B, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. 
+ q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.pos_embed.height, '') + _assert(W == self.pos_embed.width, '') + + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. 
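+        # Shapes after the reshapes below: q -> (B*num_heads, H*W, dim_head_qk),
+        # k -> (B*num_heads, dim_head_qk, H*W), v -> (B*num_heads, H*W, dim_head_v),
+        # so attn = q @ k is (B*num_heads, H*W, H*W).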
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) + + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/flagai/model/vision/layers/cbam.py b/flagai/model/vision/layers/cbam.py new file mode 100755 index 00000000..576a8306 --- /dev/null +++ b/flagai/model/vision/layers/cbam.py @@ -0,0 +1,112 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn +import torch.nn.functional as F + +from .conv_bn_act import ConvNormAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(ChannelAttn, self).__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * F.sigmoid(x_attn) + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(SpatialAttn, self).__init__() + self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
+ """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(LightSpatialAttn, self).__init__() + self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class CbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(CbamModule, self).__init__() + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/flagai/model/vision/layers/classifier.py b/flagai/model/vision/layers/classifier.py new file mode 100755 index 00000000..3ac33387 --- /dev/null +++ b/flagai/model/vision/layers/classifier.py @@ -0,0 +1,56 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d + + +def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + else: + fc = nn.Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, 
use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + if pre_logits: + return x.flatten(1) + else: + x = self.fc(x) + return self.flatten(x) diff --git a/flagai/model/vision/layers/cond_conv2d.py b/flagai/model/vision/layers/cond_conv2d.py new file mode 100755 index 00000000..43654c59 --- /dev/null +++ b/flagai/model/vision/layers/cond_conv2d.py @@ -0,0 +1,123 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + 
init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/flagai/model/vision/layers/config.py b/flagai/model/vision/layers/config.py new file mode 100755 index 00000000..f07b9d78 --- /dev/null +++ b/flagai/model/vision/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. 
+_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/flagai/model/vision/layers/conv2d_same.py b/flagai/model/vision/layers/conv2d_same.py new file mode 100755 index 00000000..75f0f98d --- /dev/null +++ b/flagai/model/vision/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + 
kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/flagai/model/vision/layers/conv_bn_act.py b/flagai/model/vision/layers/conv_bn_act.py new file mode 100755 index 00000000..af010573 --- /dev/null +++ b/flagai/model/vision/layers/conv_bn_act.py @@ -0,0 +1,73 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer + + +class ConvNormAct(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): + super(ConvNormAct, self).__init__() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +ConvBnAct = ConvNormAct + + +class ConvNormActAa(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): + super(ConvNormActAa, self).__init__() + use_aa = aa_layer is not None + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.aa(x) + return x diff --git a/flagai/model/vision/layers/create_act.py b/flagai/model/vision/layers/create_act.py new file mode 100755 index 00000000..e38f2e03 --- /dev/null +++ b/flagai/model/vision/layers/create_act.py @@ -0,0 +1,148 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Union, Callable, Type + +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 
1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +def get_act_fn(name: Union[Callable, str] = 'relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
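+
+    A minimal usage sketch (assuming the default, non-scripted layer config):
+
+        act = get_act_fn('swish')    # resolves to torch.nn.functional.silu on PyTorch >= 1.7
+        y = act(torch.randn(2, 8))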
+ """ + if not name: + return None + if isinstance(name, Callable): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if not isinstance(name, str): + # callable, module, etc + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/flagai/model/vision/layers/create_attn.py b/flagai/model/vision/layers/create_attn.py new file mode 100755 index 00000000..028c0f75 --- /dev/null +++ b/flagai/model/vision/layers/create_attn.py @@ -0,0 +1,89 @@ +""" Attention Factory + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from functools import partial + +from .bottleneck_attn import BottleneckAttn +from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn +from .squeeze_excite import SEModule, EffectiveSEModule + + +def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext + elif attn_type == 'gca': + module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. 
+ elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + return module_cls + + +def create_attn(attn_type, channels, **kwargs): + module_cls = get_attn(attn_type) + if module_cls is not None: + # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels + return module_cls(channels, **kwargs) + return None diff --git a/flagai/model/vision/layers/create_conv2d.py b/flagai/model/vision/layers/create_conv2d.py new file mode 100755 index 00000000..ac9489ce --- /dev/null +++ b/flagai/model/vision/layers/create_conv2d.py @@ -0,0 +1,36 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + if 'groups' in kwargs: + groups = kwargs.pop('groups') + if groups == in_channels: + kwargs['depthwise'] = True + else: + assert groups == 1 + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/flagai/model/vision/layers/create_norm_act.py b/flagai/model/vision/layers/create_norm_act.py new file mode 100755 index 00000000..cd15c2f8 --- /dev/null +++ b/flagai/model/vision/layers/create_norm_act.py @@ -0,0 +1,88 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. 
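+
+A minimal sketch of the intended use (names below are defined in this module):
+
+    norm_act = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)  # -> partial of BatchNormAct2d
+    layer = norm_act(64)                                              # fused norm + act for 64 channels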
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +from .evo_norm import * +from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d +from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d +from .inplace_abn import InplaceAbn + +_NORM_ACT_MAP = dict( + batchnorm=BatchNormAct2d, + batchnorm2d=BatchNormAct2d, + groupnorm=GroupNormAct, + layernorm=LayerNormAct, + layernorm2d=LayerNormAct2d, + evonormb0=EvoNorm2dB0, + evonormb1=EvoNorm2dB1, + evonormb2=EvoNorm2dB2, + evonorms0=EvoNorm2dS0, + evonorms0a=EvoNorm2dS0a, + evonorms1=EvoNorm2dS1, + evonorms1a=EvoNorm2dS1a, + evonorms2=EvoNorm2dS2, + evonorms2a=EvoNorm2dS2a, + frn=FilterResponseNormAct2d, + frntlu=FilterResponseNormTlu2d, + inplaceabn=InplaceAbn, + iabn=InplaceAbn, +) +_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} +# has act_layer arg to define act type +_NORM_ACT_REQUIRES_ARG = { + BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} + + +def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): + layer = get_norm_act_layer(layer_name, act_layer=act_layer) + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def get_norm_act_layer(norm_layer, act_layer=None): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + layer_name = norm_layer.replace('_', '').lower().split('-')[0] + norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + elif type_name.startswith('layernorm2d'): + norm_act_layer = LayerNormAct2d + elif type_name.startswith('layernorm'): + norm_act_layer = LayerNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/flagai/model/vision/layers/drop.py b/flagai/model/vision/layers/drop.py new file mode 100755 index 00000000..ae065277 --- /dev/null +++ b/flagai/model/vision/layers/drop.py @@ -0,0 +1,166 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 
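+
+Typical (sketch) usage of the layers defined here: DropBlock2d(drop_prob=0.1, block_size=7) on
+NCHW feature maps, and DropPath(drop_prob) on the residual branch of a block; both are no-ops at
+inference time.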
+ +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + block_mask = torch.empty_like(x).bernoulli_(gamma) + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.empty_like(x).normal_() + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. 
- block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + + def __init__( + self, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + fast: bool = True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
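+    Typical (illustrative) use inside a block's forward: x = shortcut + self.drop_path(branch(x)),
+    where branch is whatever transformation the block applies.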
+ """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/flagai/model/vision/layers/eca.py b/flagai/model/vision/layers/eca.py new file mode 100755 index 00000000..e29be6ac --- /dev/null +++ b/flagai/model/vision/layers/eca.py @@ -0,0 +1,145 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) 
+ kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): + super(EcaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + assert kernel_size % 2 == 1 + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +EfficientChannelAttn = EcaModule # alias + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) 
+ kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + super(CecaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + has_act = act_layer is not None + assert kernel_size % 2 == 1 + + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding + self.padding = (kernel_size - 1) // 2 + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) + # Manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y, (self.padding, self.padding), mode='circular') + y = self.conv(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule diff --git a/flagai/model/vision/layers/evo_norm.py b/flagai/model/vision/layers/evo_norm.py new file mode 100755 index 00000000..b643302c --- /dev/null +++ b/flagai/model/vision/layers/evo_norm.py @@ -0,0 +1,350 @@ +""" EvoNorm in PyTorch + +Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967 +@inproceedings{NEURIPS2020, + author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin}, + pages = {13539--13550}, + publisher = {Curran Associates, Inc.}, + title = {Evolving Normalization-Activation Layers}, + url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf}, + volume = {33}, + year = {2020} +} + +An attempt at getting decent performing EvoNorms running in PyTorch. +While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm +in terms of memory usage and throughput on GPUs. + +I'm testing these modules on TPU w/ PyTorch XLA. Promising start but +currently working around some issues with builtin torch/tensor.var/std. Unlike +GPU, similar train speeds for EvoNormS variants and BatchNorm. 
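+
+Illustrative (sketch) usage: EvoNorm2dS0(64) or EvoNorm2dB0(64) act as drop-in replacements for a
+BatchNorm2d + activation pair on NCHW feature maps (channel count must be divisible by `groups`
+for the S variants).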
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer +from .trace_utils import _assert + + +def instance_std(x, eps: float = 1e-5): + std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) + return std.expand(x.shape) + + +def instance_std_tpu(x, eps: float = 1e-5): + std = manual_var(x, dim=(2, 3)).add(eps).sqrt() + return std.expand(x.shape) +# instance_std = instance_std_tpu + + +def instance_rms(x, eps: float = 1e-5): + rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype) + return rms.expand(x.shape) + + +def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False): + xm = x.mean(dim=dim, keepdim=True) + if diff_sqm: + # difference of squared mean and mean squared, faster on TPU can be less stable + var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0) + else: + var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True) + return var + + +def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): + B, C, H, W = x.shape + x_dtype = x.dtype + _assert(C % groups == 0, '') + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + else: + x = x.reshape(B, groups, C // groups, H, W) + std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + return std.expand(x.shape).reshape(B, C, H, W) + + +def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False): + # This is a workaround for some stability / odd behaviour of .var and .std + # running on PyTorch XLA w/ TPUs. 
These manual var impl are producing much better results + B, C, H, W = x.shape + _assert(C % groups == 0, '') + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + var = manual_var(x, dim=-1, diff_sqm=diff_sqm) + else: + x = x.reshape(B, groups, C // groups, H, W) + var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm) + return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W) +#group_std = group_std_tpu # FIXME TPU temporary + + +def group_rms(x, groups: int = 32, eps: float = 1e-5): + B, C, H, W = x.shape + _assert(C % groups == 0, '') + x_dtype = x.dtype + x = x.reshape(B, groups, C // groups, H, W) + rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype) + return rms.expand(x.shape).reshape(B, C, H, W) + + +class EvoNorm2dB0(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.v is not None: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.v is not None: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + # var = manual_var(x, dim=(0, 2, 3)).squeeze() + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach() * self.momentum * (n / (n - 1))) + else: + var = self.running_var + left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x) + v = self.v.to(x_dtype).view(v_shape) + right = x * v + instance_std(x, self.eps) + x = x / left.max(right) + return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape) + + +class EvoNorm2dB1(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) + else: + var = self.running_var + var = var.to(x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = (x + 1) * instance_rms(x, self.eps) + x = x / left.max(right) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dB2(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + 
self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) + else: + var = self.running_var + var = var.to(x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = instance_rms(x, self.eps) - x + x = x / left.max(right) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS0(nn.Module): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.v is not None: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.v is not None: + v = self.v.view(v_shape).to(x_dtype) + x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS0a(EvoNorm2dS0): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + d = group_std(x, self.groups, self.eps) + if self.v is not None: + v = self.v.view(v_shape).to(x_dtype) + x = x * (x * v).sigmoid() + x = x / d + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS1(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.pre_act_norm = False + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_std(x, self.groups, 
self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS1a(EvoNorm2dS1): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS2(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) + + +class EvoNorm2dS2a(EvoNorm2dS2): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) diff --git a/flagai/model/vision/layers/filter_response_norm.py b/flagai/model/vision/layers/filter_response_norm.py new file mode 100755 index 00000000..a66a1cd4 --- /dev/null +++ b/flagai/model/vision/layers/filter_response_norm.py @@ -0,0 +1,68 @@ +""" Filter Response Norm in PyTorch + +Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +import torch.nn as nn + +from .create_act import create_act_layer +from .trace_utils import _assert + + +def inv_instance_rms(x, eps: float = 1e-5): + rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) + return rms.expand(x.shape) + + +class FilterResponseNormTlu2d(nn.Module): + def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): + super(FilterResponseNormTlu2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.tau is not None: + nn.init.zeros_(self.tau) + + def 
forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x + + +class FilterResponseNormAct2d(nn.Module): + def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): + super(FilterResponseNormAct2d, self).__init__() + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer, inplace=inplace) + else: + self.act = nn.Identity() + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return self.act(x) diff --git a/flagai/model/vision/layers/gather_excite.py b/flagai/model/vision/layers/gather_excite.py new file mode 100755 index 00000000..2d60dc96 --- /dev/null +++ b/flagai/model/vision/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. 
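+
+Illustrative (sketch) usage: GatherExcite(channels=256, extent=0, extra_params=False) gives the
+parameter-free global-extent form, which reduces to a squeeze-and-excite style gate as noted below.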
+ +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
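+        # e.g. channels=64 with the default rd_ratio=1/16 and rd_divisor=1 gives rd_channels=4 (illustrative values)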
+ self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/flagai/model/vision/layers/global_context.py b/flagai/model/vision/layers/global_context.py new file mode 100755 index 00000000..de7fb5c1 --- /dev/null +++ b/flagai/model/vision/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
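+        # e.g. channels=256 with the default rd_ratio=1/8 gives a rd_channels=32 bottleneck for the fuse MLPs (illustrative values)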
+ if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/flagai/model/vision/layers/halo_attn.py b/flagai/model/vision/layers/halo_attn.py new file mode 100755 index 00000000..f2ac64f8 --- /dev/null +++ b/flagai/model/vision/layers/halo_attn.py @@ -0,0 +1,233 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. +The attention mechanism works but it's slow as implemented. 
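+
+A hedged usage sketch (illustrative only; all sizes are assumptions, and H, W must be multiples of block_size):
+
+    attn = HaloAttn(dim=128, dim_out=128, num_heads=8, block_size=8, halo_size=3)
+    y = attn(torch.randn(2, 128, 32, 32))  # stride=1, so the (2, 128, 32, 32) shape is preserved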
+ +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +from torch import nn +import torch.nn.functional as F + +from .helpers import make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + """ + def __init__(self, block_size, win_size, dim_head, scale): + """ + Args: + block_size (int): block size + win_size (int): neighbourhood window size + dim_head (int): attention head dim + scale (float): scale factor (for init) + """ + super().__init__() + self.block_size = block_size + self.dim_head = dim_head + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + + def forward(self, q): + B, BB, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(-1, self.block_size, self.block_size, self.dim_head) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, BB, HW, -1) + return rel_logits + + +class HaloAttn(nn.Module): + """ Halo Attention + + Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda) + stride: output stride of the module, query downscaled if > 1 (default: 1). 
+ num_heads: parallel attention heads (default: 8). + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + block_size (int): size of blocks. (default: 8) + halo_size (int): size of halo overlap. (default: 3) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool) : add bias to q, k, and v projections + avg_down (bool): use average pool downsample instead of strided query blocks + scale_pos_embed (bool): scale the position embedding as well as Q @ K + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, + qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + assert stride in (1, 2) + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + self.block_size = self.block_size_ds = block_size + self.halo_size = halo_size + self.win_size = block_size + halo_size * 2 # neighbourhood window size + self.block_stride = 1 + use_avg_pool = False + if stride > 1: + use_avg_pool = avg_down or block_size % stride != 0 + self.block_stride = 1 if use_avg_pool else stride + self.block_size_ds = self.block_size // self.block_stride + + # FIXME not clear if this stride behaviour is what the paper intended + # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving + # data in unfolded block form. I haven't wrapped my head around how that'd look. + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) + + self.pos_embed = PosEmbedRel( + block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + std = self.q.weight.shape[1] ** -0.5 # fan-in + trunc_normal_(self.q.weight, std=std) + trunc_normal_(self.kv.weight, std=std) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + _assert(H % self.block_size == 0, '') + _assert(W % self.block_size == 0, '') + num_h_blocks = H // self.block_size + num_w_blocks = W // self.block_size + num_blocks = num_h_blocks * num_w_blocks + + q = self.q(x) + # unfold + q = q.reshape( + -1, self.dim_head_qk, + num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4) + # B, num_heads * dim_head * block_size ** 2, num_blocks + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) + # B * num_heads, num_blocks, block_size ** 2, dim_head + + kv = self.kv(x) + # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not + # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach. + # FIXME figure out how to switch impl between this and conv2d if XLA being used. 
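+        # Illustrative shape walk-through (assuming H=W=32, block_size=8, halo_size=3): kv is padded to 38x38,
+        # then unfolded into a 4x4 grid of overlapping 14x14 (win_size) windows, one window per query block.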
+ kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) + kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) + k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) + # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v + + if self.scale_pos_embed: + attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale + else: + attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q) + # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks + # fold + out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view( + B, self.dim_out_v, H // self.block_stride, W // self.block_stride) + # B, dim_out, H // block_stride, W // block_stride + out = self.pool(out) + return out + + +""" Three alternatives for overlapping windows. + +`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() + + if is_xla: + # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is + # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment. + WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" diff --git a/flagai/model/vision/layers/helpers.py b/flagai/model/vision/layers/helpers.py new file mode 100755 index 00000000..cc54ca7f --- /dev/null +++ b/flagai/model/vision/layers/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
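+    # e.g. make_divisible(22) -> 24, while make_divisible(10) -> 16 because rounding down to 8 would lose more than 10%.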
+ if new_v < round_limit * v: + new_v += divisor + return new_v diff --git a/flagai/model/vision/layers/inplace_abn.py b/flagai/model/vision/layers/inplace_abn.py new file mode 100755 index 00000000..a8088933 --- /dev/null +++ b/flagai/model/vision/layers/inplace_abn.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_layer=None): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity', '') + self.act_name = act_layer if act_layer else 'identity' + else: + # convert act layer passed as type to string + if act_layer == nn.ELU: + self.act_name = 'elu' + elif act_layer == nn.LeakyReLU: + self.act_name = 'leaky_relu' + elif act_layer is None or act_layer == nn.Identity: + self.act_name = 'identity' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/flagai/model/vision/layers/lambda_layer.py b/flagai/model/vision/layers/lambda_layer.py new file mode 100755 index 00000000..e50b43c8 --- /dev/null +++ b/flagai/model/vision/layers/lambda_layer.py @@ -0,0 +1,133 @@ +""" Lambda Layer + +Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + +@misc{2102.08602, +Author = {Irwan Bello}, 
+Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, +Year = {2021}, +} + +Status: +This impl is a WIP. Code snippets in the paper were used as reference but +good chance some details are missing/wrong. + +I've only implemented local lambda conv based pos embeddings. + +For a PyTorch impl that includes other embedding options checkout +https://github.com/lucidrains/lambda-networks + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from torch import nn +import torch.nn.functional as F + +from .helpers import to_2tuple, make_divisible +from .weight_init import trunc_normal_ + + +def rel_pos_indices(size): + size = to_2tuple(size) + pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) + rel_pos = pos[:, None, :] - pos[:, :, None] + rel_pos[0] += size[0] - 1 + rel_pos[1] += size[1] - 1 + return rel_pos # 2, H * W, H * W + + +class LambdaLayer(nn.Module): + """Lambda Layer + + Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + + NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. + + The internal dimensions of the lambda module are controlled via the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query (q) and key (k) dimension are determined by + * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None + * q = num_heads * dim_head, k = dim_head + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W + stride (int): output stride of the module, avg pool used if stride == 2 + num_heads (int): parallel attention heads. + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, + qk_ratio=1.0, qkv_bias=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.num_heads = num_heads + self.dim_v = dim_out // num_heads + + self.qkv = nn.Conv2d( + dim, + num_heads * self.dim_qk + self.dim_qk + self.dim_v, + kernel_size=1, bias=qkv_bias) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) + self.norm_v = nn.BatchNorm2d(self.dim_v) + + if r is not None: + # local lambda convolution for pos + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.pos_emb = None + self.rel_pos_indices = None + else: + # relative pos embedding + assert feat_size is not None + feat_size = to_2tuple(feat_size) + rel_size = [2 * s - 1 for s in feat_size] + self.conv_lambda = None + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) + self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in + if self.conv_lambda is not None: + trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) + if self.pos_emb is not None: + trunc_normal_(self.pos_emb, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + M = H * W + qkv = self.qkv(x) + q, k, v = torch.split(qkv, [ + self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K + v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V + k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M + + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + + if self.pos_emb is None: + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + else: + # FIXME relative pos embedding path not fully verified + pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) + position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V + position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V + + out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = self.pool(out) + return out diff --git a/flagai/model/vision/layers/linear.py b/flagai/model/vision/layers/linear.py new file mode 100755 index 00000000..38fe3380 --- /dev/null +++ b/flagai/model/vision/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
+ """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/flagai/model/vision/layers/median_pool.py b/flagai/model/vision/layers/median_pool.py new file mode 100755 index 00000000..40bd71a7 --- /dev/null +++ b/flagai/model/vision/layers/median_pool.py @@ -0,0 +1,49 @@ +""" Median Pool +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F +from .helpers import to_2tuple, to_4tuple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. + + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x diff --git a/flagai/model/vision/layers/mixed_conv2d.py b/flagai/model/vision/layers/mixed_conv2d.py new file mode 100755 index 00000000..fa0ce565 --- /dev/null +++ b/flagai/model/vision/layers/mixed_conv2d.py @@ -0,0 +1,51 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = in_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + 
create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/flagai/model/vision/layers/ml_decoder.py b/flagai/model/vision/layers/ml_decoder.py new file mode 100755 index 00000000..3f828c6d --- /dev/null +++ b/flagai/model/vision/layers/ml_decoder.py @@ -0,0 +1,156 @@ +from typing import Optional + +import torch +from torch import nn +from torch import nn, Tensor +from torch.nn.modules.transformer import _get_activation_fn + + +def add_ml_decoder_head(model): + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50 + model.global_pool = nn.Identity() + del model.fc + num_classes = model.num_classes + num_features = model.num_features + model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet + model.global_pool = nn.Identity() + del model.classifier + num_classes = model.num_classes + num_features = model.num_features + model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head') + del model.head + num_classes = model.num_classes + num_features = model.num_features + model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + else: + print("Model code-writing is not aligned currently with ml-decoder") + exit(-1) + if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout + model.drop_rate = 0 + return model + + +class TransformerDecoderLayerOptimal(nn.Module): + def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", + layer_norm_eps=1e-5) -> None: + super(TransformerDecoderLayerOptimal, self).__init__() + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = torch.nn.functional.relu + super(TransformerDecoderLayerOptimal, self).__setstate__(state) + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +# @torch.jit.script +# class ExtrapClasses(object): +# def __init__(self, num_queries: int, group_size: int): +# self.num_queries = num_queries +# self.group_size = 
group_size +# +# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap: +# torch.Tensor): +# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size) +# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups]) +# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size)) +# out = (h * w).sum(dim=2) + class_embed_b +# out = out.view((h.shape[0], self.group_size * self.num_queries)) +# return out + +@torch.jit.script +class GroupFC(object): + def __init__(self, embed_len_decoder: int): + self.embed_len_decoder = embed_len_decoder + + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): + for i in range(self.embed_len_decoder): + h_i = h[:, i, :] + w_i = duplicate_pooling[i, :, :] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) + + +class MLDecoder(nn.Module): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): + super(MLDecoder, self).__init__() + embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups + if embed_len_decoder > num_classes: + embed_len_decoder = num_classes + + # switching to 768 initial embeddings + decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding + self.embed_standart = nn.Linear(initial_num_features, decoder_embedding) + + # decoder + decoder_dropout = 0.1 + num_layers_decoder = 1 + dim_feedforward = 2048 + layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, + dim_feedforward=dim_feedforward, dropout=decoder_dropout) + self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) + + # non-learnable queries + self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) + self.query_embed.requires_grad_(False) + + # group fully-connected + self.num_classes = num_classes + self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) + self.duplicate_pooling = torch.nn.Parameter( + torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)) + self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) + torch.nn.init.xavier_normal_(self.duplicate_pooling) + torch.nn.init.constant_(self.duplicate_pooling_bias, 0) + self.group_fc = GroupFC(embed_len_decoder) + + def forward(self, x): + if len(x.shape) == 4: # [bs,2048, 7,7] + embedding_spatial = x.flatten(2).transpose(1, 2) + else: # [bs, 197,468] + embedding_spatial = x + embedding_spatial_786 = self.embed_standart(embedding_spatial) + embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True) + + bs = embedding_spatial_786.shape[0] + query_embed = self.query_embed.weight + # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand + h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] + h = h.transpose(0, 1) + + out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + self.group_fc(h, self.duplicate_pooling, out_extrap) + h_out = out_extrap.flatten(1)[:, :self.num_classes] + h_out += self.duplicate_pooling_bias + logits = h_out + return logits diff --git a/flagai/model/vision/layers/mlp.py b/flagai/model/vision/layers/mlp.py new file mode 100755 index 00000000..91e80a84 --- /dev/null +++ b/flagai/model/vision/layers/mlp.py @@ -0,0 +1,126 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / 
Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .helpers import to_2tuple + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
+ else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, + norm_layer=None, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/flagai/model/vision/layers/non_local_attn.py b/flagai/model/vision/layers/non_local_attn.py new file mode 100755 index 00000000..670e8f24 --- /dev/null +++ b/flagai/model/vision/layers/non_local_attn.py @@ -0,0 +1,145 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvNormAct +from .helpers import make_divisible +from .trace_utils import _assert + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. 
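+
+    A hedged usage sketch (channel and shape values are illustrative assumptions):
+
+        nl = NonLocalAttn(in_channels=64)
+        y = nl(torch.randn(2, 64, 16, 16))  # residual output keeps the (2, 64, 16, 16) shape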
+ """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t: int): + B, C, block_size, block_size1 = x.shape + _assert(block_size == block_size1, '') + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + _assert(x.shape[-1] % self.block_size == 0, '') + _assert(x.shape[-2] % self.block_size == 0, '') + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // 
self.groups, self.block_size, self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/flagai/model/vision/layers/norm.py b/flagai/model/vision/layers/norm.py new file mode 100755 index 00000000..85297420 --- /dev/null +++ b/flagai/model/vision/layers/norm.py @@ -0,0 +1,24 @@ +""" Normalization layers and wrappers +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm(nn.GroupNorm): + def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): + # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN + super().__init__(num_groups, num_channels, eps=eps, affine=affine) + + def forward(self, x): + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm2d(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial BCHW tensors """ + def __init__(self, num_channels): + super().__init__(num_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) diff --git a/flagai/model/vision/layers/norm_act.py b/flagai/model/vision/layers/norm_act.py new file mode 100755 index 00000000..34c4fd64 --- /dev/null +++ b/flagai/model/vision/layers/norm_act.py @@ -0,0 +1,151 @@ +""" Normalization + Activation Layers +""" +from typing import Union, List + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .trace_utils import _assert +from .create_act import get_act_layer + + +class BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. 
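+
+    A hedged usage sketch (values are illustrative assumptions, not from the original source):
+
+        bn_act = BatchNormAct2d(32)
+        y = bn_act(torch.randn(4, 32, 8, 8))  # batch norm followed by the default ReLU, shape preserved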
+ """ + def __init__( + self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing + _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)') + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). 
+ """ + x = F.batch_norm( + x, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + x = self.drop(x) + x = self.act(x) + return x + + +def _num_groups(num_channels, num_groups, group_size): + if group_size: + assert num_channels % group_size == 0 + return num_channels // group_size + return num_groups + + +class GroupNormAct(nn.GroupNorm): + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args + def __init__( + self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(GroupNormAct, self).__init__( + _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct(nn.LayerNorm): + def __init__( + self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct2d(nn.LayerNorm): + def __init__( + self, num_channels, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = self.drop(x) + x = self.act(x) + return x diff --git a/flagai/model/vision/layers/padding.py b/flagai/model/vision/layers/padding.py new file mode 100755 index 00000000..34afc37c --- /dev/null +++ b/flagai/model/vision/layers/padding.py @@ -0,0 +1,56 @@ +""" Padding Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from typing import List, Tuple + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate 
asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/flagai/model/vision/layers/patch_embed.py b/flagai/model/vision/layers/patch_embed.py new file mode 100755 index 00000000..b074798b --- /dev/null +++ b/flagai/model/vision/layers/patch_embed.py @@ -0,0 +1,40 @@ +""" Image to Patch Embedding using Conv2d + +A convolution based approach to patchifying a 2D image w/ embedding projection. 
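+
+A hedged usage sketch (the ViT-Base style settings below are assumptions chosen for illustration):
+
+    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
+    tokens = embed(torch.randn(1, 3, 224, 224))  # (1, 196, 768): a 14x14 grid of 768-dim patch tokens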
+ +Based on the impl in https://github.com/google-research/vision_transformer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .helpers import to_2tuple +from .trace_utils import _assert + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x diff --git a/flagai/model/vision/layers/pool2d_same.py b/flagai/model/vision/layers/pool2d_same.py new file mode 100755 index 00000000..4c2a1c44 --- /dev/null +++ b/flagai/model/vision/layers/pool2d_same.py @@ -0,0 +1,73 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple, Optional + +from .helpers import to_2tuple +from .padding import pad_same, get_padding_value + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? 
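+    # Illustrative example: for a 7-wide input with kernel 3 and stride 2, get_same_padding returns 2, so one
+    # zero column/row goes on each side and the pooled output width is ceil(7 / 2) = 4, matching TF 'SAME' output size.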
+ x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/flagai/model/vision/layers/pos_embed.py b/flagai/model/vision/layers/pos_embed.py new file mode 100755 index 00000000..99a122a0 --- /dev/null +++ b/flagai/model/vision/layers/pos_embed.py @@ -0,0 +1,207 @@ +import math +from typing import List, Tuple, Optional, Union + +import torch +from torch import nn as nn + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224., + linear_bands: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + else: + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) + return bands * torch.pi + + +def inv_freq_bands( + num_bands: int, + temperature: float = 100000., + step: int = 2, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + inv_freq = 1. 
/ (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + return inv_freq + + +def build_sincos2d_pos_embed( + feat_shape: List[int], + dim: int = 64, + temperature: float = 10000., + reverse_coord: bool = False, + interleave_sin_cos: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None +) -> torch.Tensor: + """ + + Args: + feat_shape: + dim: + temperature: + reverse_coord: stack grid order W, H instead of H, W + interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos + dtype: + device: + + Returns: + + """ + assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' + pos_dim = dim // 4 + bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + + if reverse_coord: + feat_shape = feat_shape[::-1] # stack W, H instead of H, W + grid = torch.stack( + torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) + # FIXME add support for unflattened spatial dim? + + stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos + pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) + return pos_emb + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + linear_bands: bool = False, + include_grid: bool = False, + concat_out: bool = True, + in_pixels: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + if bands is None: + if in_pixels: + bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) + else: + bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + grid = torch.stack(torch.meshgrid( + [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + else: + grid = torch.stack(torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin(), pos.cos() + out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) + # FIXME torchscript doesn't like multiple return types, probably need to always cat? 
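For orientation, the fixed sin-cos table produced by `build_sincos2d_pos_embed` above has one row per spatial position and `dim` columns; a short illustrative sketch (assuming the `flagai.model.vision.layers.pos_embed` module path used in this diff):

```python
from flagai.model.vision.layers.pos_embed import build_sincos2d_pos_embed

# 14x14 feature grid (e.g. a 224px image with 16px patches); dim must be divisible by 4.
pos_emb = build_sincos2d_pos_embed(feat_shape=[14, 14], dim=768)
print(pos_emb.shape)   # torch.Size([196, 768])
```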
+ if concat_out: + out = torch.cat(out, dim=-1) + return out + + +class FourierEmbed(nn.Module): + + def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): + super().__init__() + self.max_res = max_res + self.num_bands = num_bands + self.concat_grid = concat_grid + self.keep_spatial = keep_spatial + self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) + + def forward(self, x): + B, C = x.shape[:2] + feat_shape = x.shape[2:] + emb = build_fourier_pos_embed( + feat_shape, + self.bands, + include_grid=self.concat_grid, + dtype=x.dtype, + device=x.device) + emb = emb.transpose(-1, -2).flatten(len(feat_shape)) + batch_expand = (B,) + (-1,) * (x.ndim - 1) + + # FIXME support nD + if self.keep_spatial: + x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) + else: + x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) + x = x.reshape(B, feat_shape.numel(), -1) + + return x + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +def apply_rot_embed_split(x: torch.Tensor, emb): + split = emb.shape[-1] // 2 + return x * emb[:, :split] + rot(x) * emb[:, split:] + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_freq: float = 224, + linear_bands: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + NOTE: shape arg should include spatial dim only + """ + feat_shape = torch.Size(feat_shape) + + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands, + concat_out=False, device=device, dtype=dtype) + N = feat_shape.numel() + sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +class RotaryEmbedding(nn.Module): + """ Rotary position embedding + + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. 
+ + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + def __init__(self, dim, max_res=224, linear_bands: bool = False): + super().__init__() + self.dim = dim + self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + + def get_embed(self, shape: List[int]): + return build_rotary_pos_embed(shape, self.bands) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) diff --git a/flagai/model/vision/layers/selective_kernel.py b/flagai/model/vision/layers/selective_kernel.py new file mode 100755 index 00000000..3d71e3aa --- /dev/null +++ b/flagai/model/vision/layers/selective_kernel.py @@ -0,0 +1,119 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvNormActAa +from .helpers import make_divisible +from .trace_utils import _assert + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + _assert(x.shape[1] == self.num_paths, '') + x = x.sum(1).mean((2, 3), keepdim=True) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernel(nn.Module): + + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. 
-Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + rd_ratio (int, float): reduction factor for attention features + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + aa_layer (nn.Module): anti-aliasing module + drop_layer (nn.Module): spatial drop module in convs (drop block, etc) + """ + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, drop_layer=drop_layer) + self.paths = nn.ModuleList([ + ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/flagai/model/vision/layers/separable_conv.py b/flagai/model/vision/layers/separable_conv.py new file mode 100755 index 00000000..c081e02b --- /dev/null +++ b/flagai/model/vision/layers/separable_conv.py @@ -0,0 +1,76 @@ +""" Depthwise Separable Conv Modules + +Basic DWS convs. Other variations of DWS exist with batch norm or activations between the +DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
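To show what the `SelectiveKernel` module added above does end to end, here is a minimal, illustrative sketch. It assumes that `ConvNormActAa` (imported from `.conv_bn_act`, not shown in this diff) behaves like its timm counterpart and preserves spatial size at stride 1.

```python
import torch
from flagai.model.vision.layers.selective_kernel import SelectiveKernel

# Defaults: two branches, both 3x3 (the nominal 5x5 branch becomes 3x3 with
# dilation because keep_3x3=True); split_input=True divides the input channels.
sk = SelectiveKernel(in_channels=32, out_channels=64)
x = torch.randn(2, 32, 56, 56)
print(sk(x).shape)   # expected: torch.Size([2, 64, 56, 56])
```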
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer + + +class SeparableConvNormAct(nn.Module): + """ Separable Conv w/ trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, + apply_act=True, drop_layer=None): + super(SeparableConvNormAct, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + x = self.bn(x) + return x + + +SeparableConvBnAct = SeparableConvNormAct + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/flagai/model/vision/layers/space_to_depth.py b/flagai/model/vision/layers/space_to_depth.py new file mode 100755 index 00000000..a7e8e0b2 --- /dev/null +++ b/flagai/model/vision/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class 
DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x diff --git a/flagai/model/vision/layers/split_attn.py b/flagai/model/vision/layers/split_attn.py new file mode 100755 index 00000000..ac54f898 --- /dev/null +++ b/flagai/model/vision/layers/split_attn.py @@ -0,0 +1,84 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + +from .helpers import make_divisible + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) + """ + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels + self.radix = radix + mid_chs = out_channels * radix + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + + padding = kernel_size // 2 if padding is None else padding + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + x = self.bn0(x) + x = self.drop(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.fc1(x_gap) + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/flagai/model/vision/layers/split_batchnorm.py b/flagai/model/vision/layers/split_batchnorm.py new file mode 100755 index 00000000..830781b3 --- /dev/null +++ 
b/flagai/model/vision/layers/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/flagai/model/vision/layers/squeeze_excite.py b/flagai/model/vision/layers/squeeze_excite.py new file mode 100755 index 00000000..e5da29ef --- /dev/null +++ b/flagai/model/vision/layers/squeeze_excite.py @@ -0,0 +1,74 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. 
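As a small, illustrative check of the `SplitBatchNorm2d` layer added above (the only assumption is the module path used by the rest of this PR): in training mode the batch is chopped into `num_splits` chunks, the first updating the main BN statistics and the rest updating the `.aux_bn` copies.

```python
import torch
from flagai.model.vision.layers.split_batchnorm import SplitBatchNorm2d

bn = SplitBatchNorm2d(num_features=8, num_splits=2)
bn.train()
x = torch.randn(4, 8, 16, 16)   # batch size must be divisible by num_splits
print(bn(x).shape)              # torch.Size([4, 8, 16, 16]); eval() uses only the main BN
```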
+Has since evolved with additional functionality / configuration. + +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias diff --git a/flagai/model/vision/layers/std_conv.py b/flagai/model/vision/layers/std_conv.py new file mode 100755 index 00000000..d896ba5c --- /dev/null +++ b/flagai/model/vision/layers/std_conv.py @@ -0,0 +1,133 @@ +""" Convolution with Weight Standardization (StdConv and ScaledStdConv) + +StdConv: +@article{weightstandardization, + author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, + title = {Weight Standardization}, + journal = {arXiv preprint arXiv:1903.10520}, + year = {2019}, +} +Code: https://github.com/joe-siyuan-qiao/WeightStandardization + +ScaledStdConv: +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Hacked together by / copyright Ross 
Wightman, 2021. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .padding import get_padding, get_padding_value, pad_same + + +class StdConv2d(nn.Conv2d): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=False, eps=1e-6): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class StdConv2dSame(nn.Conv2d): + """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=False, eps=1e-6): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class ScaledStdConv2d(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization. + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
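A brief illustrative sketch of the `StdConv2d` layer defined earlier in this file (it assumes `get_padding` from the padding helpers resolves the default symmetric padding, as in timm): the weight is standardized on the fly inside `forward()` via `F.batch_norm` over each output filter, so `self.weight` itself and existing checkpoints are left untouched.

```python
import torch
from flagai.model.vision.layers.std_conv import StdConv2d

conv = StdConv2d(3, 64, kernel_size=3, stride=1)
x = torch.randn(1, 3, 32, 32)
print(conv(x).shape)   # torch.Size([1, 64, 32, 32])
```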
+ """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ScaledStdConv2dSame(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/flagai/model/vision/layers/test_time_pool.py b/flagai/model/vision/layers/test_time_pool.py new file mode 100755 index 00000000..98c0bf53 --- /dev/null +++ b/flagai/model/vision/layers/test_time_pool.py @@ -0,0 +1,52 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) + self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) + self.base.reset_classifier(0) # delete original fc layer + + def forward(self, x): + x = self.base.forward_features(x) + x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) + x = self.fc(x) + x = adaptive_avgmax_pool2d(x, 1) + return 
x.view(x.size(0), -1) + + +def apply_test_time_pool(model, config, use_test_size=True): + test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False + if use_test_size and 'test_input_size' in model.default_cfg: + df_input_size = model.default_cfg['test_input_size'] + else: + df_input_size = model.default_cfg['input_size'] + if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: + _logger.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(df_input_size[-2:]))) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/flagai/model/vision/layers/trace_utils.py b/flagai/model/vision/layers/trace_utils.py new file mode 100755 index 00000000..83970729 --- /dev/null +++ b/flagai/model/vision/layers/trace_utils.py @@ -0,0 +1,13 @@ +try: + from torch import _assert +except ImportError: + def _assert(condition: bool, message: str): + assert condition, message + + +def _float_to_int(x: float) -> int: + """ + Symbolic tracing helper to substitute for inbuilt `int`. + Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) diff --git a/flagai/model/vision/layers/weight_init.py b/flagai/model/vision/layers/weight_init.py new file mode 100755 index 00000000..24c0fa7c --- /dev/null +++ b/flagai/model/vision/layers/weight_init.py @@ -0,0 +1,88 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') diff --git a/flagai/model/vision/vit.py b/flagai/model/vision/vit.py new file mode 100644 index 00000000..6fc93865 --- /dev/null +++ b/flagai/model/vision/vit.py @@ -0,0 +1,497 @@ +""" +# Copyright © 2022 BAAI. All rights reserved. +""" + +""" +Vision Transformer (ViT) in PyTorch +A PyTorch implement of Vision Transformers as described in: +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 +The official jax code is released and available at https://github.com/google-research/vision_transformer +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert +Hacked together by / Copyright 2020, Ross Wightman +""" + +import math +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from typing import Callable +from flagai.model.vision.layers.patch_embed import PatchEmbed +from flagai.model.vision.layers.mlp import Mlp +from flagai.model.vision.layers.drop import DropPath +from flagai.model.vision.layers.weight_init import trunc_normal_, lecun_normal_ +from flagai.model.base_model import BaseModel +from flagai.model.vision.helpers import checkpoint_seq + +class VitConfig: + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + checkpoint_activations=False): + pass + self.img_size=img_size + self.patch_size=patch_size + self.in_chans=in_chans + self.num_classes=num_classes + self.global_pool=global_pool + self.embed_dim=embed_dim + self.depth=depth + self.num_heads=num_heads + self.mlp_ratio=mlp_ratio + self.qkv_bias=qkv_bias + self.init_values=init_values + self.class_token=class_token + self.fc_norm=fc_norm + self.drop_rate=drop_rate + self.attn_drop_rate=attn_drop_rate + self.drop_path_rate=drop_path_rate + self.weight_init=weight_init + self.checkpoint_activations = checkpoint_activations + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. 
+ repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + +class VisionTransformer(BaseModel): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, config, num_classes=1000): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__(config) + embed_layer=PatchEmbed + block_fn=Block + vit_config = VitConfig(**config) + vit_config.num_classes = num_classes + # config = vit_config + + assert vit_config.global_pool in ('', 'avg', 'token') + assert vit_config.class_token or vit_config.global_pool != 'token' + use_fc_norm = vit_config.global_pool == 'avg' if vit_config.fc_norm is None else vit_config.fc_norm + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = vit_config.global_pool + self.num_features = self.embed_dim = vit_config.embed_dim # num_features for consistency with other models + self.num_tokens = 1 if vit_config.class_token else 0 + self.grad_checkpointing = vit_config.checkpoint_activations + + self.patch_embed = embed_layer( + img_size=vit_config.img_size, patch_size=vit_config.patch_size, in_chans=vit_config.in_chans, embed_dim=vit_config.embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, vit_config.embed_dim)) if self.num_tokens > 0 else None + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, vit_config.embed_dim) * .02) + self.pos_drop = nn.Dropout(p=vit_config.drop_rate) + + dpr = [x.item() for x in torch.linspace(0, vit_config.drop_path_rate, vit_config.depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=vit_config.embed_dim, num_heads=vit_config.num_heads, mlp_ratio=vit_config.mlp_ratio, qkv_bias=vit_config.qkv_bias, init_values=vit_config.init_values, + drop=vit_config.drop_rate, attn_drop=vit_config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(vit_config.depth)]) + self.norm = norm_layer(vit_config.embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(vit_config.embed_dim) if use_fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if vit_config.weight_init != 'skip': + 
self.init_weights(vit_config.weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_weights(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + if self.config["checkpoint_activations"]: + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def compute_loss(self, logits, labels): + loss_func = nn.CrossEntropyLoss() + return loss_func(logits, labels) + + def forward(self, images=None, labels=None, **kwargs): + + x = self.forward_features(images) + x = self.forward_head(x) + loss = None + if labels is not None: + loss = self.compute_loss(x, labels) + return_data = {"logits": x, "hidden_states": x, "loss": loss} + + return return_data + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): 
+ if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), 
nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + + + diff --git a/flagai/trainer.py b/flagai/trainer.py index 7ad20a1c..f2fd8dd1 100644 --- a/flagai/trainer.py +++ b/flagai/trainer.py @@ -198,22 +198,30 @@ def __init__( self.not_call_launch = True self.deepspeed_config = deepspeed_config self.model_parallel_size = model_parallel_size - if 'deepspeed' in env_type or env_type == 'pytorchDDP': + self.num_nodes = num_nodes + 
+        self.num_gpus = num_gpus
+        self.master_ip = master_ip
+        self.master_port = master_port
+        self.hostfile = hostfile
+        self.training_script = training_script
+
+        if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP':
+            training_paras = self.get_dist_args()
             # Implement for AutoLaunch
             # >>> python train.py # will call get_dist_args()
             # `--not_call_launch` is default 'False'
             # So, if `env_type` is `pytorch`, the `Trainer` will not call launch_dist()
             # Otherwise, launch_dist() is called to launch 'train.py' with `--not_call_launch`
-            self.get_dist_args()
             if not self.not_call_launch:
                 launch_dist(launcher='distributed_deepspeed' if 'deepspeed' in env_type else 'distributed_torch',
-                            num_nodes=num_nodes,
-                            gpus_per_node=num_gpus,
-                            master_addr=master_ip,
-                            master_port=master_port,
-                            hostfile=hostfile,
-                            training_script=training_script)
+                            num_nodes=self.num_nodes,
+                            gpus_per_node=self.num_gpus,
+                            master_addr=self.master_ip,
+                            master_port=self.master_port,
+                            hostfile=self.hostfile,
+                            training_script=self.training_script,
+                            training_paras=training_paras)
                 os._exit(1)
         self.initialize_distributed()
@@ -239,6 +247,7 @@ def get_dist_args(self):
         self.master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
         self.master_port = os.environ.get('MASTER_PORT', '17500')
         log_dist("not_call_launch: {}".format(ds_args.not_call_launch))
+        return []
 
     def set_seed(self, seed=1234):
         """Set random seed for reproducibility."""
@@ -310,6 +319,9 @@ def get_dataloader(self, dataset, collate_fn, shuffle=False):
         else:
             if self.env_type == 'deepspeed+mpu':
                 rank = mpu.get_model_parallel_src_rank()
+                print("*" * 80)
+                print("local rank", self.rank, "model rank", rank)
+                print("*" * 80)
             sampler = torch.utils.data.distributed.DistributedSampler(
                 dataset,  # num_replicas=num_replicas,
@@ -336,7 +348,8 @@ def train(self,
               train_dataset=None,
               valid_dataset=None,
               metric_methods=[],
-              collate_fn=None):
+              collate_fn=None,
+              find_unused_parameters=True):
         """Training Loops"""
         """ Trainer is a simple but unified training and eval loop for PyTorch/Deepspeed/Megatron-LM.
@@ -404,7 +417,7 @@ def train(self,
             model.to(torch.device('cuda', self.local_rank))
             model = DDP(model, device_ids=[self.local_rank],
-                        find_unused_parameters=True)
+                        find_unused_parameters=find_unused_parameters)
         elif self.env_type == 'pytorch':
             model.to(self.pytorch_device)
@@ -506,7 +519,8 @@ def train(self,
                                     lr_scheduler,
                                     single_step=True)
                 dist.barrier()
-                total_lm_loss += lm_loss.data.detach().float()
+                if lm_loss is not None:
+                    total_lm_loss += lm_loss.data.detach().float()
                # Logging.
                if (self.iteration + 1) % self.log_interval == 0:
@@ -1025,3 +1039,4 @@ def evaluate_and_print_results(
         log_dist(string, [0])
         log_dist('-' * length, [0])
         return eval_dict
+
diff --git a/flagai/utils.py b/flagai/utils.py
index 64845757..06c30026 100644
--- a/flagai/utils.py
+++ b/flagai/utils.py
@@ -206,8 +206,7 @@ def save_checkpoint(iteration,
             sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
     if env_type == 'pytorch' or (env_type != 'deepspeed+mpu' and dist.get_rank() == 0) or (
-            env_type == 'deepspeed+mpu'
-            and mpu.get_model_parallel_src_rank() == 0):
+            env_type == 'deepspeed+mpu' and mpu.get_model_parallel_src_rank() == 0):
         ensure_directory_exists(checkpoint_name)
         config_path = os.path.join(save_dir, str(iteration), 'config.json')
@@ -220,6 +219,7 @@ def save_checkpoint(iteration,
     tracker_filename = get_checkpoint_tracker_filename(save_dir)
     with open(tracker_filename, 'w') as f:
         f.write(str(iteration) + '\t' + str(best_iteration))
+
     # Wait so everyone is done (necessary)
     if barrier and dist.is_initialized():
         torch.distributed.barrier()
diff --git a/flagai_wechat.png b/flagai_wechat.png
index 387bded2..e9dd1d30 100644
Binary files a/flagai_wechat.png and b/flagai_wechat.png differ
diff --git a/setup.py b/setup.py
index 33c2090b..b36d19f8 100644
--- a/setup.py
+++ b/setup.py
@@ -5,8 +5,7 @@
 setup(
     name="flagai",
-    version="v1.1.3",
-
+    version="v1.2.0",
     description="FlagAI aims to help researchers and developers to freely train and test large-scale models for NLP tasks.",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
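
For reference, a minimal sketch of how the `resize_pos_embed` helper added earlier in this diff behaves when a pretrained position embedding does not match the target model's grid. The tensor shapes below are illustrative assumptions (a 14x14 pretrained grid plus one class token, resized for a 16x16 grid), and the function itself is assumed to be in scope from the module this diff introduces; this is not part of the patch, only a usage sketch.

```python
import torch

# Assumed shapes: 14x14 pretrained patch grid + 1 class token, target model expects 16x16.
# `resize_pos_embed` is the function defined in the vision transformer code above.
old_grid, new_grid, dim = 14, 16, 768
posemb_pretrained = torch.randn(1, 1 + old_grid * old_grid, dim)   # checkpoint pos_embed
posemb_target = torch.zeros(1, 1 + new_grid * new_grid, dim)       # model.pos_embed buffer

# Class token is kept as-is; the grid part is bicubically interpolated to 16x16.
resized = resize_pos_embed(posemb_pretrained, posemb_target,
                           num_tokens=1, gs_new=(new_grid, new_grid))
assert resized.shape == posemb_target.shape  # (1, 257, 768)
```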