diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index 7488487c93..b9e2aad745 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -1,6 +1,7 @@ import gzip import hashlib import os +import sys import shutil import tarfile import tempfile @@ -9,11 +10,12 @@ import numpy as np import torch -from torch._six import PY3 from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg +PY3 = sys.version_info[0] == 3 + if _TORCHVISION_AVAILABLE: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 0c345cc941..ff95063aad 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -27,7 +27,7 @@ class AE(pl.LightningModule): ae = AE() # pretrained on cifar10 - ae = AE.from_pretrained('cifar10-resnet18') + ae = AE(input_height=32).from_pretrained('cifar10-resnet18') """ pretrained_urls = { diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index e0cb157bc3..0b2d45f09d 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -27,10 +27,10 @@ class VAE(pl.LightningModule): vae = VAE() # pretrained on cifar10 - vae = VAE.from_pretrained('cifar10-resnet18') + vae = VAE(input_height=32).from_pretrained('cifar10-resnet18') # pretrained on stl10 - vae = VAE.from_pretrained('stl10-resnet18') + vae = VAE(input_height=32).from_pretrained('stl10-resnet18') """ pretrained_urls = {