diff --git a/main_landec.py b/main_landec.py
index 9cc6ef95..9787c47d 100644
--- a/main_landec.py
+++ b/main_landec.py
@@ -1,5 +1,7 @@
 import time
 import torch
+if torch.backends.cudnn.version() < 8000:
+    torch.backends.cudnn.benchmark = True
 # torch.multiprocessing.set_sharing_strategy('file_system')
 import resource
 import argparse
@@ -51,6 +53,8 @@
 parser.add_argument('--encoder-only', action='store_true', default=False,
                     help='Only train the encoder. ENet trains encoder and decoder separately (default: False)')
 args = parser.parse_args()
+if args.mixed_precision and torch.__version__ < '1.6.0':
+    print('PyTorch version too low, mixed precision training is not available.')
 exp_name = str(time.time()) if args.exp_name == '' else args.exp_name
 states = ['train', 'valfast', 'test', 'val']
 with open(exp_name + '_' + states[args.state] + '_cfg.txt', 'w') as f:
diff --git a/main_semseg.py b/main_semseg.py
index b8220afe..c9b7081b 100644
--- a/main_semseg.py
+++ b/main_semseg.py
@@ -1,6 +1,8 @@
 import os
 import time
 import torch
+if torch.backends.cudnn.version() < 8000:
+    torch.backends.cudnn.benchmark = True
 import argparse
 import math
 import yaml
@@ -43,6 +45,8 @@
 parser.add_argument('--encoder-only', action='store_true', default=False,
                     help='Only train the encoder. ENet trains encoder and decoder separately (default: False)')
 args = parser.parse_args()
+if args.mixed_precision and torch.__version__ < '1.6.0':
+    print('PyTorch version too low, mixed precision training is not available.')
 exp_name = str(time.time()) if args.exp_name == '' else args.exp_name
 with open(exp_name + '_cfg.txt', 'w') as f:
     f.write(str(vars(args)))
diff --git a/tools/vis_tools.py b/tools/vis_tools.py
index e783c85d..c40717ef 100644
--- a/tools/vis_tools.py
+++ b/tools/vis_tools.py
@@ -3,7 +3,10 @@
 import numpy as np
 import cv2
 import torch
-from torch.cuda.amp import autocast
+if torch.__version__ >= '1.6.0':
+    from torch.cuda.amp import autocast
+else:
+    from utils.torch_amp_dummy import autocast
 from enum import Enum
 from PIL import Image
 from transforms import ToTensor, Resize, ZeroPad, Normalize, Compose
diff --git a/utils/__init__.py b/utils/__init__.py
index 53ec5366..dbf23bae 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -2,3 +2,4 @@
 from . import losses
 from . import all_utils_semseg
 from . import all_utils_landec
+from . import torch_amp_dummy
diff --git a/utils/all_utils_landec.py b/utils/all_utils_landec.py
index e34ac44c..b39dcb44 100644
--- a/utils/all_utils_landec.py
+++ b/utils/all_utils_landec.py
@@ -5,7 +5,10 @@
 import ujson as json
 import numpy as np
 from tqdm import tqdm
-from torch.cuda.amp import autocast, GradScaler
+if torch.__version__ >= '1.6.0':
+    from torch.cuda.amp import autocast, GradScaler
+else:
+    from .torch_amp_dummy import autocast, GradScaler
 from torchvision_models.segmentation import erfnet_resnet, deeplabv1_vgg16, deeplabv1_resnet18, deeplabv1_resnet34, \
     deeplabv1_resnet50, deeplabv1_resnet101, enet_
 from torchvision_models.lane_detection import LSTR
diff --git a/utils/all_utils_semseg.py b/utils/all_utils_semseg.py
index 10a4dd19..74a31fb5 100644
--- a/utils/all_utils_semseg.py
+++ b/utils/all_utils_semseg.py
@@ -2,7 +2,10 @@
 from collections import OrderedDict
 import torch
 import warnings
-from torch.cuda.amp import autocast, GradScaler
+if torch.__version__ >= '1.6.0':
+    from torch.cuda.amp import autocast, GradScaler
+else:
+    from .torch_amp_dummy import autocast, GradScaler
 from tqdm import tqdm
 from torchvision_models.segmentation import deeplabv2_resnet101, deeplabv3_resnet101, fcn_resnet101, erfnet_resnet, \
     enet_
diff --git a/utils/torch_amp_dummy.py b/utils/torch_amp_dummy.py
new file mode 100644
index 00000000..58f657fb
--- /dev/null
+++ b/utils/torch_amp_dummy.py
@@ -0,0 +1,39 @@
+# Provide a dummy torch.amp utils for better coverage of lower PyTorch versions
+import functools
+
+
+class autocast(object):
+    def __init__(self, enabled=True):
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        pass
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def decorate_autocast(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+        return decorate_autocast
+
+
+class GradScaler(object):
+    def __init__(self,
+                 init_scale=2.**16,
+                 growth_factor=2.0,
+                 backoff_factor=0.5,
+                 growth_interval=2000,
+                 enabled=True):
+        pass
+
+    def scale(self, outputs):
+        return outputs
+
+    def step(self, optimizer, *args, **kwargs):
+        return optimizer.step(*args, **kwargs)
+
+    def update(self):
+        pass
diff --git a/visualize_lane.py b/visualize_lane.py
index ad41e0d7..70be13ae 100644
--- a/visualize_lane.py
+++ b/visualize_lane.py
@@ -3,7 +3,6 @@
 import argparse
 import torch
 import cv2
-from torch.cuda.amp import autocast
 from tqdm import tqdm
 from cv2 import VideoWriter_fourcc
 from mmcv import VideoReader
@@ -55,6 +54,8 @@
 parser.add_argument('--workers', type=int, default=0,
                     help='Number of workers (default: 0)')
 args = parser.parse_args()
+if args.mixed_precision and torch.__version__ < '1.6.0':
+    print('PyTorch version too low, mixed precision training is not available.')
 with open('configs.yaml', 'r') as f:  # Safer and cleaner than box/EasyDict
     configs = yaml.load(f, Loader=yaml.Loader)
diff --git a/visualize_segmentation.py b/visualize_segmentation.py
index 2d76bb16..4383d884 100644
--- a/visualize_segmentation.py
+++ b/visualize_segmentation.py
@@ -5,7 +5,10 @@
 import cv2
 from cv2 import VideoWriter_fourcc
 from mmcv import VideoReader
-from torch.cuda.amp import autocast
+if torch.__version__ >= '1.6.0':
+    from torch.cuda.amp import autocast
+else:
+    from utils.torch_amp_dummy import autocast
 from PIL import Image
 from tqdm import tqdm
 from utils.all_utils_semseg import load_checkpoint, build_segmentation_model
@@ -49,6 +52,8 @@
 parser.add_argument('--workers', type=int, default=0,
                     help='Number of workers (default: 0)')
 args = parser.parse_args()
+if args.mixed_precision and torch.__version__ < '1.6.0':
+    print('PyTorch version too low, mixed precision training is not available.')
 with open('configs.yaml', 'r') as f:  # Safer and cleaner than box/EasyDict
     configs = yaml.load(f, Loader=yaml.Loader)
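
Note: the new utils/torch_amp_dummy.py mirrors the torch.cuda.amp entry points this codebase touches (autocast as context manager/decorator, GradScaler.scale/step/update), so the same training loop runs unmodified on pre-1.6 PyTorch, just in full precision. Below is a minimal sketch of the pattern, assuming it runs from the repo root so utils.torch_amp_dummy is importable; the toy model, optimizer, data, and the mixed_precision flag are illustrative stand-ins for args.mixed_precision and the real networks, not identifiers from this repository:

import torch

if torch.__version__ >= '1.6.0':
    from torch.cuda.amp import autocast, GradScaler
else:
    from utils.torch_amp_dummy import autocast, GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mirrors the patch's own string comparison against the version
mixed_precision = device == 'cuda' and torch.__version__ >= '1.6.0'

model = torch.nn.Linear(10, 2).to(device)  # toy stand-in for a real net
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler(enabled=mixed_precision)

inputs = torch.randn(4, 10, device=device)
labels = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with autocast(enabled=mixed_precision):  # dummy version is a plain no-op context
    loss = criterion(model(inputs), labels)
scaler.scale(loss).backward()  # dummy scale() returns the loss unchanged
scaler.step(optimizer)         # dummy step() simply calls optimizer.step()
scaler.update()                # dummy update() does nothing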