
Commit

Minor fixes on checkpoint saving & loading, training arguments, and the training pipeline.

Bostoncake committed Sep 12, 2024
1 parent 03661cf commit 3587ee6
Showing 5 changed files with 66 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -131,5 +131,5 @@ work_dir
 saves
 ckpt_saves
 weights
-logs
+outputs
 data/vtab-1k
32 changes: 1 addition & 31 deletions engine.py
@@ -58,12 +58,8 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                     amp: bool = True, teacher_model: torch.nn.Module = None,
-                    teach_loss: torch.nn.Module = None, use_tome=False,tome_initialized=False,tome_r=None,
-                    test_batch_size=512,
+                    teach_loss: torch.nn.Module = None,
                     deit=False):

-    if use_tome:
-        assert type(tome_r) is list
-
     model.train()
     criterion.train()
@@ -95,32 +91,6 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                 loss = criterion(outputs, targets)
         else:

-            if not use_tome and not tome_initialized:
-                # latency test for no ToMe ViT
-                test_model_latency(model, device, test_batch_size)
-
-                # count flops when no tome attached
-                count_model_flops(model, device)
-
-                model.train()
-                tome_initialized = True
-
-            # use ToMe for training
-            if use_tome and not tome_initialized:
-                print("ToMe model initialization.")
-                model_module = unwrap_model(model)
-                apply_tome(model_module)
-                model_module.r = tome_r
-
-                # latency test
-                test_model_latency(model, device, test_batch_size)
-
-                # count flops
-                count_model_flops(model, device)
-
-                model.train()
-                tome_initialized = True
-
             if not deit:
                 outputs = model(samples)
             else:
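Taken together with the train.py hunks further down, this moves the one-off ToMe setup, latency test, and FLOPs count out of the training loop and into main(), so they run exactly once before the first epoch. Below is a minimal sketch of the consolidated setup, using only names that appear in this diff (apply_tome, unwrap_model, test_model_latency, count_model_flops); the helper function itself is illustrative, not part of the commit.

```python
# Illustrative only: one-time setup that mirrors the logic this commit removes
# from train_one_epoch and re-adds in train.py's main().
from timm.utils.model import unwrap_model
from model.tome import apply_tome
from engine import test_model_latency, count_model_flops

def setup_profiling_and_tome(model, device, args, tome_r):
    if args.token_merging:
        # Attach token merging once, before any batches are seen.
        print("Token merging initialization.")
        model_module = unwrap_model(model)   # strip DDP wrappers if present
        apply_tome(model_module)
        model_module.r = tome_r              # per-layer merging schedule

    # Profile once, with or without ToMe attached.
    model.to(device)
    test_model_latency(model, device, args.test_batch_size)
    count_model_flops(model, device)
    model.train()
    return model
```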
42 changes: 21 additions & 21 deletions lib/datasets.py
@@ -13,7 +13,7 @@
 from timm.data import create_transform

 class general_dataset(ImageFolder):
-    def __init__(self, root, train=True, transform=None, target_transform=None, is_individual_prompt=False,**kwargs):
+    def __init__(self, root, train=True, transform=None, target_transform=None,**kwargs):
         self.dataset_root = root
         self.loader = default_loader
         self.target_transform = None
@@ -36,65 +36,65 @@ def __init__(self, root, train=True, transform=None, target_transform=None, is_i
                 label = int(line.split(' ')[1])
                 self.samples.append((os.path.join(root,img_name), label))

-def build_dataset(is_train, args, folder_name=None,is_individual_prompt=False):
+def build_dataset(is_train, args, folder_name=None):
     transform = build_transform(is_train, args)

     if args.data_set == 'clevr_count':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 8
     elif args.data_set == 'diabetic_retinopathy':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 5
     elif args.data_set == 'dsprites_loc':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
        nb_classes = 16
     elif args.data_set == 'dtd':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 47
     elif args.data_set == 'kitti':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 4
     elif args.data_set == 'oxford_pet':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 37
     elif args.data_set == 'resisc45':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 45
     elif args.data_set == 'smallnorb_ele':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 9
     elif args.data_set == 'svhn':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 10
     elif args.data_set == 'cifar100':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 100
     elif args.data_set == 'clevr_dist':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 6
     elif args.data_set == 'caltech101':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 102
     elif args.data_set == 'dmlab':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 6
     elif args.data_set == 'dsprites_ori':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 16
     elif args.data_set == 'eurosat':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 10
     elif args.data_set == 'oxford_flowers102':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 102
     elif args.data_set == 'patch_camelyon':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 2
     elif args.data_set == 'smallnorb_azi':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 18
     elif args.data_set == 'sun397':
-        dataset = general_dataset(args.data_path, train=is_train, transform=transform,is_individual_prompt=is_individual_prompt)
+        dataset = general_dataset(args.data_path, train=is_train, transform=transform)
         nb_classes = 397

     return dataset, nb_classes
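With is_individual_prompt gone from both general_dataset and build_dataset, callers now pass only the split flag and the parsed args, matching the updated calls in train.py below. A minimal usage sketch; the args object here is hypothetical, and the real script builds it with argparse (build_transform may require additional fields not shown).

```python
from types import SimpleNamespace
from lib.datasets import build_dataset

# Hypothetical minimal args; the real values come from train.py's parser.
args = SimpleNamespace(data_set='cifar100',
                       data_path='./data/vtab-1k/cifar100')

dataset_train, nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
print(nb_classes)  # 100 for cifar100
```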
8 changes: 4 additions & 4 deletions scripts/pyra_vit-l.sh
@@ -20,20 +20,20 @@ do
 do
 for LR in 0.001
 do
-LOG_DIR=logs/${currenttime}_${CONFIG}_compress_${merge_schedule}_PYRA_LR_${pyra_lr}
+LOG_DIR=outputs/${currenttime}_${CONFIG}_compress_${merge_schedule}_PYRA_LR_${pyra_lr}

 TARGET_DIR=${LOG_DIR}/${DATASET}_lr-${LR}_wd-${WEIGHT_DECAY}
 if [ ! -d ${TARGET_DIR} ]
 then
-mkdir ${LOG_DIR}
-mkdir ${TARGET_DIR}
+mkdir -p ${LOG_DIR}
+mkdir -p ${TARGET_DIR}
 else
 echo "Dir already exists, skipping ${TARGET_DIR}"
 continue
 fi
 CUDA_VISIBLE_DEVICES=${device} python train.py --data-path=./data/vtab-1k/${DATASET} --data-set=${DATASET}\
 --cfg=${CONFIG_DIR} --resume=${CKPT} --output_dir=${TARGET_DIR}\
---batch-size=32 --lr=${LR} --epochs=100 --is_LoRA --weight-decay=${WEIGHT_DECAY}\
+--batch-size=32 --lr=${LR} --epochs=100 --weight-decay=${WEIGHT_DECAY}\
 --no_aug --mixup=0 --cutmix=0 --direct_resize --smoothing=0\
 --token_merging --merging_schedule=${merge_schedule}\
 --pyra --separate_lr_for_pyra --pyra_lr=${pyra_lr}\
69 changes: 39 additions & 30 deletions train.py
@@ -17,7 +17,7 @@

 from timm.utils import NativeScaler
 from lib.datasets import build_dataset
-from engine import train_one_epoch, evaluate
+from engine import train_one_epoch, evaluate, test_model_latency, count_model_flops
 from lib.samplers import RASampler
 from lib import utils
 from lib.config import cfg, update_config_from_file
@@ -30,7 +30,7 @@
 from mmcv.runner import get_dist_info, init_dist

 # tome
-from model.tome import parse_r, get_merging_schedule
+from model.tome import parse_r, get_merging_schedule, apply_tome
 from timm.utils.model import unwrap_model

 import os
@@ -258,12 +258,6 @@ def get_args_parser():
                         default='none',
                         help='job launcher')

-    parser.add_argument('--is_visual_prompt_tuning', action='store_true')
-    parser.add_argument('--is_adapter', action='store_true')
-    parser.add_argument('--is_LoRA', action='store_true')
-    parser.add_argument('--is_prefix', action='store_true')
-    parser.add_argument('--is_consolidator', action='store_true')
-
     parser.add_argument('--no_aug', action='store_true')

     parser.add_argument('--val_interval', default=10, type=int, help='validataion interval')
@@ -315,8 +309,8 @@ def main(args):
     np.random.seed(seed)
     # random.seed(seed)
     cudnn.benchmark = True
-    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args,is_individual_prompt=True)
-    dataset_val, _ = build_dataset(is_train=False, args=args,is_individual_prompt=True)
+    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+    dataset_val, _ = build_dataset(is_train=False, args=args)

     if args.distributed:
         num_tasks = utils.get_world_size()
@@ -389,12 +383,25 @@ def main(args):
         pyra_r = pyra_r
     )

+    # initialize token merging
+    if args.token_merging:
+        print("Token merging initialization.")
+        model_module = unwrap_model(model)
+        apply_tome(model_module)
+        model_module.r = tome_r
+
     if args.resume:
-        if 'pth' in args.resume :
-            incompatible_keys = timm_load_checkpoint(model, args.resume,strict=False)
-            print(incompatible_keys)
-            if args.nb_classes != model.head.weight.shape[0]:
-                model.reset_classifier(args.nb_classes)
+        if 'pth' in args.resume:
+            if "best_checkpoint" in args.resume:
+                if args.nb_classes != model.head.weight.shape[0]:
+                    model.reset_classifier(args.nb_classes)
+                incompatible_keys = timm_load_checkpoint(model, args.resume,strict=True)
+                print(incompatible_keys)
+            else:
+                incompatible_keys = timm_load_checkpoint(model, args.resume,strict=False)
+                print(incompatible_keys)
+                if args.nb_classes != model.head.weight.shape[0]:
+                    model.reset_classifier(args.nb_classes)
             # print("Try without loading pth")
         else:
             load_checkpoint(model, args.resume)
@@ -421,7 +428,10 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
         model_without_ddp = model.module

-
+    # test model latency and count model FLOPS
+    model.to(device)
+    test_model_latency(model, device, args.test_batch_size)
+    count_model_flops(model, device)

     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print('number of params:', n_parameters)
@@ -448,7 +458,7 @@ def main(args):
             f.write(args_text)

     if args.eval:
-        test_stats = evaluate(data_loader_val, model, device)
+        test_stats = evaluate(data_loader_val, model, device, amp=args.amp)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         return

@@ -467,28 +477,16 @@ def main(args):
             args.clip_grad, model_ema, mixup_fn,
             amp=args.amp, teacher_model=teacher_model,
             teach_loss=teacher_loss,
-            use_tome = args.token_merging, tome_initialized=tome_initialized, tome_r=tome_r,
-            test_batch_size=args.test_batch_size,
             deit="deit" in cfg.MODEL_NAME
         )
-        tome_initialized = True

         lr_scheduler.step(epoch)
-        if args.output_dir and args.save_ckpt:
-            checkpoint_paths = [output_dir / 'checkpoint.pth']
-            for checkpoint_path in checkpoint_paths:
-                utils.save_on_master({
-                    'model': model_without_ddp.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'lr_scheduler': lr_scheduler.state_dict(),
-                    'epoch': epoch,
-                    'scaler': loss_scaler.state_dict(),
-                    'args': args,
-                }, checkpoint_path)

         if epoch % args.val_interval == 0 or epoch == args.epochs-1:
             test_stats = evaluate(data_loader_val, model, device, amp=args.amp)
             print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+            prev_max_accuracy = max_accuracy
             max_accuracy = max(max_accuracy, test_stats["acc1"])
             print(f'Max accuracy: {max_accuracy:.2f}%')

@@ -500,6 +498,17 @@ def main(args):
         if args.output_dir and utils.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
                 f.write(json.dumps(log_stats) + "\n")
+            if args.save_ckpt and max_accuracy > prev_max_accuracy:
+                checkpoint_paths = [output_dir / 'best_checkpoint.pth']
+                for checkpoint_path in checkpoint_paths:
+                    utils.save_on_master({
+                        'model': model_without_ddp.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'lr_scheduler': lr_scheduler.state_dict(),
+                        'epoch': epoch,
+                        'scaler': loss_scaler.state_dict(),
+                        'args': args,
+                    }, checkpoint_path)

     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
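The per-epoch checkpoint.pth dump is replaced by a best_checkpoint.pth written only when validation accuracy improves, and the new --resume branch loads such a file strictly after resetting the classifier head. Below is a minimal sketch of that save/load contract, assuming the checkpoint keys shown in the diff; torch.save and load_state_dict stand in for the repo's utils.save_on_master and timm_load_checkpoint helpers, and both functions are illustrative rather than part of the commit.

```python
import torch

def save_best(output_dir, model_without_ddp, optimizer, lr_scheduler, loss_scaler, epoch, args):
    # Called only when max_accuracy just improved, so the file always holds
    # the best weights seen so far rather than the last epoch.
    torch.save({
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }, output_dir / 'best_checkpoint.pth')

def load_best_for_eval(model, ckpt_path, nb_classes):
    # Mirrors the best_checkpoint branch of --resume: the saved head already
    # matches the task, so reset the classifier first and load strictly.
    if nb_classes != model.head.weight.shape[0]:
        model.reset_classifier(nb_classes)
    state = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state['model'], strict=True)
    return model
```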
