Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-769] set MXNET_HOME as base for downloaded models through base.data_dir() #11636

Merged
merged 2 commits into from
Aug 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/docker_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import subprocess
import json
import build as build_util
from joblib import Parallel, delayed



Expand All @@ -43,6 +42,7 @@ def build_save_containers(platforms, registry, load_cache) -> int:
:param load_cache: Load cache before building
:return: 1 if error occurred, 0 otherwise
"""
from joblib import Parallel, delayed
if len(platforms) == 0:
return 0

Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/examples/scripts/get_cifar_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/examples/scripts/get_mnist_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/scripts/get_cifar_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/scripts/get_mnist_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 4 additions & 0 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
- Values: String ```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'```
- The repository url to be used for Gluon datasets and pre-trained models.

* MXNET_HOME
- Data directory in the filesystem for storage, for example when downloading gluon models.
- Default in *nix is .mxnet APPDATA/mxnet in windows.

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
Expand Down
24 changes: 22 additions & 2 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

import atexit
import ctypes
import inspect
import os
import sys
import warnings

import inspect
import platform
import numpy as np

from . import libinfo
Expand Down Expand Up @@ -59,6 +59,26 @@
py_str = lambda x: x


def data_dir_default():
"""

:return: default data directory depending on the platform and environment variables
"""
system = platform.system()
if system == 'Windows':
return os.path.join(os.environ.get('APPDATA'), 'mxnet')
else:
return os.path.join(os.path.expanduser("~"), '.mxnet')


def data_dir():
"""

:return: data directory in the filesystem for storage, for example when downloading models
"""
return os.getenv('MXNET_HOME', data_dir_default())


class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
Expand Down
9 changes: 5 additions & 4 deletions python/mxnet/contrib/text/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import vocab
from ... import ndarray as nd
from ... import registry
from ... import base


def register(embedding_cls):
Expand Down Expand Up @@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding):
----------
pretrained_file_name : str, default 'glove.840B.300d.txt'
The name of the pre-trained token embedding file.
embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
Expand Down Expand Up @@ -541,7 +542,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return archive

def __init__(self, pretrained_file_name='glove.840B.300d.txt',
embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
GloVe._check_pretrained_file_names(pretrained_file_name)

Expand Down Expand Up @@ -600,7 +601,7 @@ class FastText(_TokenEmbedding):
----------
pretrained_file_name : str, default 'wiki.en.vec'
The name of the pre-trained token embedding file.
embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
Expand Down Expand Up @@ -642,7 +643,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'

def __init__(self, pretrained_file_name='wiki.simple.vec',
embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
FastText._check_pretrained_file_names(pretrained_file_name)

Expand Down
11 changes: 5 additions & 6 deletions python/mxnet/gluon/contrib/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from ...data import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from ....contrib import text
from .... import nd

from .... import nd, base

class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method
def __init__(self, root, namespace, vocabulary):
Expand Down Expand Up @@ -116,7 +115,7 @@ class WikiText2(_WikiText):

Parameters
----------
root : str, default '~/.mxnet/datasets/wikitext-2'
root : str, default $MXNET_HOME/datasets/wikitext-2
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
Expand All @@ -127,7 +126,7 @@ class WikiText2(_WikiText):
The sequence length of each sample, regardless of the sentence boundary.

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-2'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
Expand All @@ -154,7 +153,7 @@ class WikiText103(_WikiText):

Parameters
----------
root : str, default '~/.mxnet/datasets/wikitext-103'
root : str, default $MXNET_HOME/datasets/wikitext-103
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
Expand All @@ -164,7 +163,7 @@ class WikiText103(_WikiText):
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-103'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
self._data_file = {'train': ('wiki.train.tokens',
Expand Down
18 changes: 9 additions & 9 deletions python/mxnet/gluon/data/vision/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .. import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from .... import nd, image, recordio
from .... import nd, image, recordio, base


class MNIST(dataset._DownloadedDataset):
Expand All @@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset):

Parameters
----------
root : str, default '~/.mxnet/datasets/mnist'
root : str, default $MXNET_HOME/datasets/mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
Expand Down Expand Up @@ -101,7 +101,7 @@ class FashionMNIST(MNIST):

Parameters
----------
root : str, default '~/.mxnet/datasets/fashion-mnist'
root : str, default $MXNET_HOME/datasets/fashion-mnist'
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -112,7 +112,7 @@ class FashionMNIST(MNIST):
transform=lambda data, label: (data.astype(np.float32)/255, label)

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
Expand All @@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset):

Parameters
----------
root : str, default '~/.mxnet/datasets/cifar10'
root : str, default $MXNET_HOME/datasets/cifar10
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'),
train=True, transform=None):
self._train = train
self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
Expand Down Expand Up @@ -197,7 +197,7 @@ class CIFAR100(CIFAR10):

Parameters
----------
root : str, default '~/.mxnet/datasets/cifar100'
root : str, default $MXNET_HOME/datasets/cifar100
Path to temp folder for storing data.
fine_label : bool, default False
Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels.
Expand All @@ -210,7 +210,7 @@ class CIFAR100(CIFAR10):
transform=lambda data, label: (data.astype(np.float32)/255, label)

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar100'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
fine_label=False, train=True, transform=None):
self._train = train
self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
Expand Down
17 changes: 9 additions & 8 deletions python/mxnet/gluon/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
__all__ = ['get_model_file', 'purge']
import os
import zipfile
import logging

from ..utils import download, check_sha1
from ... import base, util

_model_sha1 = {name: checksum for checksum, name in [
('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
Expand Down Expand Up @@ -68,7 +70,7 @@ def short_hash(name):
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]

def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
r"""Return location for the pretrained on local file system.

This function will download from online model zoo when model cannot be found or has mismatch.
Expand All @@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
----------
name : str
Name of the model.
root : str, default '~/.mxnet/models'
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.

Returns
Expand All @@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
if check_sha1(file_path, sha1_hash):
return file_path
else:
print('Mismatch in the content of model file detected. Downloading again.')
logging.warning('Mismatch in the content of model file detected. Downloading again.')
else:
print('Model file is not found. Downloading.')
logging.info('Model file not found. Downloading to %s.', file_path)

if not os.path.exists(root):
os.makedirs(root)
util.makedirs(root)

zip_file_path = os.path.join(root, file_name+'.zip')
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
Expand All @@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
else:
raise ValueError('Downloaded file has different hash. Please try again.')

def purge(root=os.path.join('~', '.mxnet', 'models')):
def purge(root=os.path.join(base.data_dir(), 'models')):
r"""Purge all pretrained model files in local file store.

Parameters
----------
root : str, default '~/.mxnet/models'
root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
root = os.path.expanduser(root)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_model(name, **kwargs):
Number of classes for the output layer.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.

Returns
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ....context import cpu
from ...block import HybridBlock
from ... import nn
from .... import base

# Net
class AlexNet(HybridBlock):
Expand Down Expand Up @@ -68,7 +69,7 @@ def hybrid_forward(self, F, x):

# Constructor
def alexnet(pretrained=False, ctx=cpu(),
root=os.path.join('~', '.mxnet', 'models'), **kwargs):
root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

Parameters
Expand All @@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = AlexNet(**kwargs)
Expand Down
Loading