Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support lower PyTorch & CuDNN versions #25

Merged
merged 2 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions main_landec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time
import torch
if torch.backends.cudnn.version() < 8000:
torch.backends.cudnn.benchmark = True
# torch.multiprocessing.set_sharing_strategy('file_system')
import resource
import argparse
Expand Down Expand Up @@ -51,6 +53,8 @@
parser.add_argument('--encoder-only', action='store_true', default=False,
help='Only train the encoder. ENet trains encoder and decoder separately (default: False)')
args = parser.parse_args()
if args.mixed_precision and torch.__version__ < '1.6.0':
print('PyTorch version too low, mixed precision training is not available.')
exp_name = str(time.time()) if args.exp_name == '' else args.exp_name
states = ['train', 'valfast', 'test', 'val']
with open(exp_name + '_' + states[args.state] + '_cfg.txt', 'w') as f:
Expand Down
4 changes: 4 additions & 0 deletions main_semseg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import time
import torch
if torch.backends.cudnn.version() < 8000:
torch.backends.cudnn.benchmark = True
import argparse
import math
import yaml
Expand Down Expand Up @@ -43,6 +45,8 @@
parser.add_argument('--encoder-only', action='store_true', default=False,
help='Only train the encoder. ENet trains encoder and decoder separately (default: False)')
args = parser.parse_args()
if args.mixed_precision and torch.__version__ < '1.6.0':
print('PyTorch version too low, mixed precision training is not available.')
exp_name = str(time.time()) if args.exp_name == '' else args.exp_name
with open(exp_name + '_cfg.txt', 'w') as f:
f.write(str(vars(args)))
Expand Down
5 changes: 4 additions & 1 deletion tools/vis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import numpy as np
import cv2
import torch
from torch.cuda.amp import autocast
if torch.__version__ >= '1.6.0':
from torch.cuda.amp import autocast
else:
from utils.torch_amp_dummy import autocast
from enum import Enum
from PIL import Image
from transforms import ToTensor, Resize, ZeroPad, Normalize, Compose
Expand Down
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from . import losses
from . import all_utils_semseg
from . import all_utils_landec
from . import torch_amp_dummy
5 changes: 4 additions & 1 deletion utils/all_utils_landec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import ujson as json
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
if torch.__version__ >= '1.6.0':
from torch.cuda.amp import autocast, GradScaler
else:
from .torch_amp_dummy import autocast, GradScaler
from torchvision_models.segmentation import erfnet_resnet, deeplabv1_vgg16, deeplabv1_resnet18, deeplabv1_resnet34, \
deeplabv1_resnet50, deeplabv1_resnet101, enet_
from torchvision_models.lane_detection import LSTR
Expand Down
5 changes: 4 additions & 1 deletion utils/all_utils_semseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from collections import OrderedDict
import torch
import warnings
from torch.cuda.amp import autocast, GradScaler
if torch.__version__ >= '1.6.0':
from torch.cuda.amp import autocast, GradScaler
else:
from .torch_amp_dummy import autocast, GradScaler
from tqdm import tqdm
from torchvision_models.segmentation import deeplabv2_resnet101, deeplabv3_resnet101, fcn_resnet101, erfnet_resnet, \
enet_
Expand Down
39 changes: 39 additions & 0 deletions utils/torch_amp_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Provide a dummy torch.amp utils for better coverage of lower PyTorch versions
import functools


class autocast(object):
def __init__(self, enabled=True):
pass

def __enter__(self):
pass

def __exit__(self, *args):
pass

def __call__(self, func):
@functools.wraps(func)
def decorate_autocast(*args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate_autocast


class GradScaler(object):
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
pass

def scale(self, outputs):
return outputs

def step(self, optimizer, *args, **kwargs):
return optimizer.step(*args, **kwargs)

def update(self):
pass
3 changes: 2 additions & 1 deletion visualize_lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import argparse
import torch
import cv2
from torch.cuda.amp import autocast
from tqdm import tqdm
from cv2 import VideoWriter_fourcc
from mmcv import VideoReader
Expand Down Expand Up @@ -55,6 +54,8 @@
parser.add_argument('--workers', type=int, default=0,
help='Number of workers (default: 0)')
args = parser.parse_args()
if args.mixed_precision and torch.__version__ < '1.6.0':
print('PyTorch version too low, mixed precision training is not available.')
with open('configs.yaml', 'r') as f: # Safer and cleaner than box/EasyDict
configs = yaml.load(f, Loader=yaml.Loader)

Expand Down
7 changes: 6 additions & 1 deletion visualize_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import cv2
from cv2 import VideoWriter_fourcc
from mmcv import VideoReader
from torch.cuda.amp import autocast
if torch.__version__ >= '1.6.0':
from torch.cuda.amp import autocast
else:
from utils.torch_amp_dummy import autocast
from PIL import Image
from tqdm import tqdm
from utils.all_utils_semseg import load_checkpoint, build_segmentation_model
Expand Down Expand Up @@ -49,6 +52,8 @@
parser.add_argument('--workers', type=int, default=0,
help='Number of workers (default: 0)')
args = parser.parse_args()
if args.mixed_precision and torch.__version__ < '1.6.0':
print('PyTorch version too low, mixed precision training is not available.')
with open('configs.yaml', 'r') as f: # Safer and cleaner than box/EasyDict
configs = yaml.load(f, Loader=yaml.Loader)

Expand Down