Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed mxContext from core. #977

Merged
merged 1 commit into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 0 additions & 69 deletions src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,75 +448,6 @@ def init_wrapper(*args, **kwargs):
return validator


class MXContext:
"""
Defines `custom data type validation
<https://pydantic-docs.helpmanual.io/#custom-data-types>`_ for
the :class:`~mxnet.context.Context` data type.
"""

@classmethod
def validate(cls, v: Union[str, mx.Context]) -> mx.Context:
if isinstance(v, mx.Context):
return v

m = re.search(r"^(?P<dev_type>cpu|gpu)(\((?P<dev_id>\d+)\))?$", v)

if m:
return mx.Context(m["dev_type"], int(m["dev_id"] or 0))
else:
raise ValueError(
f"bad MXNet context {v}, expected either an "
f"mx.context.Context or its string representation"
)

@classmethod
def __get_validators__(cls) -> mx.Context:
yield cls.validate


mx.Context.validate = MXContext.validate
mx.Context.__get_validators__ = MXContext.__get_validators__


NUM_GPUS = None


def num_gpus(refresh=False):
global NUM_GPUS
if NUM_GPUS is None or refresh:
n = 0
try:
n = mx.context.num_gpus()
except mx.base.MXNetError as e:
logger.error(f"Failure when querying GPU: {e}")
NUM_GPUS = n
return NUM_GPUS


@functools.lru_cache()
def get_mxnet_context(gpu_number=0) -> mx.Context:
"""
Returns either CPU or GPU context
"""
if num_gpus():
logger.info("Using GPU")
return mx.context.gpu(gpu_number)
else:
logger.info("Using CPU")
return mx.context.cpu()


def check_gpu_support() -> bool:
"""
Emits a log line and returns a boolean that indicate whether
the currently installed MXNet version has GPU support.
"""
n = num_gpus()
logger.info(f'MXNet GPU support is {"ON" if n > 0 else "OFF"}')
return n != 0


class DType:
"""
Defines `custom data type validation
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
DType,
equals,
from_hyperparameters,
get_mxnet_context,
validated,
)
from gluonts.core.exception import GluonTSException
from gluonts.core.serde import dump_json, fqname_for, load_json
from gluonts.dataset.common import DataEntry, Dataset, ListDataset
from gluonts.dataset.loader import DataBatch, InferenceDataLoader
from gluonts.model.forecast import Forecast
from gluonts.mx.context import get_mxnet_context
from gluonts.mx.distribution import Distribution, DistributionOutput
from gluonts.support.util import (
export_repr_block,
Expand Down
90 changes: 90 additions & 0 deletions src/gluonts/mx/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import functools
import logging
import re
from typing import Union

import mxnet as mx

logger = logging.getLogger(__name__)


class MXContext:
"""
Defines `custom data type validation
<https://pydantic-docs.helpmanual.io/#custom-data-types>`_ for
the :class:`~mxnet.context.Context` data type.
"""

@classmethod
def validate(cls, v: Union[str, mx.Context]) -> mx.Context:
if isinstance(v, mx.Context):
return v

m = re.search(r"^(?P<dev_type>cpu|gpu)(\((?P<dev_id>\d+)\))?$", v)

if m:
return mx.Context(m["dev_type"], int(m["dev_id"] or 0))
else:
raise ValueError(
f"bad MXNet context {v}, expected either an "
f"mx.context.Context or its string representation"
)

@classmethod
def __get_validators__(cls) -> mx.Context:
yield cls.validate


mx.Context.validate = MXContext.validate
mx.Context.__get_validators__ = MXContext.__get_validators__


NUM_GPUS = None


def num_gpus(refresh=False):
global NUM_GPUS
if NUM_GPUS is None or refresh:
n = 0
try:
n = mx.context.num_gpus()
except mx.base.MXNetError as e:
logger.error(f"Failure when querying GPU: {e}")
NUM_GPUS = n
return NUM_GPUS


@functools.lru_cache()
def get_mxnet_context(gpu_number=0) -> mx.Context:
"""
Returns either CPU or GPU context
"""
if num_gpus():
logger.info("Using GPU")
return mx.context.gpu(gpu_number)
else:
logger.info("Using CPU")
return mx.context.cpu()


def check_gpu_support() -> bool:
"""
Emits a log line and returns a boolean that indicate whether
the currently installed MXNet version has GPU support.
"""
n = num_gpus()
logger.info(f'MXNet GPU support is {"ON" if n > 0 else "OFF"}')
return n != 0
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/custom_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .binning_helpers import bin_edges_from_bin_centers
from .representation import Representation
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/dim_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .representation import Representation

Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/discrete_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from mxnet.gluon import nn

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .representation import Representation

Expand Down
5 changes: 3 additions & 2 deletions src/gluonts/mx/representation/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from mxnet.gluon import nn

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .representation import Representation

Expand All @@ -30,7 +31,7 @@ class Embedding(Representation):
"""
A class representing an embedding operation on top of a given binning.
Note that this representation is intended to applied on top of categorical/binned data.

Parameters
----------
num_bins
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/global_relative_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .binning_helpers import (
bin_edges_from_bin_centers,
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/hybrid_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .representation import Representation

Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from mxnet.gluon import nn

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context


class Representation(nn.HybridBlock):
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/representation/representation_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.common import Tensor
from gluonts.mx.context import get_mxnet_context

from .representation import Representation

Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/mx/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import numpy as np

# First-party imports
from gluonts.core.component import get_mxnet_context, validated
from gluonts.core.component import validated
from gluonts.core.exception import GluonTSDataError, GluonTSUserError
from gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader
from gluonts.gluonts_tqdm import tqdm
from gluonts.mx.context import get_mxnet_context
from gluonts.support.util import HybridContext

# Relative imports
Expand Down
3 changes: 0 additions & 3 deletions src/gluonts/shell/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
# First-party imports
import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import check_gpu_support
from gluonts.model.estimator import Estimator
from gluonts.model.predictor import Predictor
from gluonts.shell.sagemaker import ServeEnv
Expand Down Expand Up @@ -120,8 +119,6 @@ def make_gunicorn_app(
forecaster_type: Optional[Type[Union[Estimator, Predictor]]],
settings: Settings,
) -> Application:
check_gpu_support()

if forecaster_type is not None:
logger.info(f"Using dynamic predictor factory")

Expand Down
3 changes: 0 additions & 3 deletions src/gluonts/shell/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
# First-party imports
import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import check_gpu_support
from gluonts.core.serde import dump_code
from gluonts.dataset.common import Dataset
from gluonts.evaluation import Evaluator, backtest
Expand Down Expand Up @@ -57,8 +56,6 @@ def log_metric(metric: str, value: Any) -> None:
def run_train_and_test(
env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]]
) -> None:
check_gpu_support()

# train_stats = calculate_dataset_statistics(env.datasets["train"])
# log_metric("train_dataset_stats", train_stats)

Expand Down
2 changes: 1 addition & 1 deletion test/support/test_jitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys

# First-party imports
from gluonts.core.component import check_gpu_support
from gluonts.mx.context import check_gpu_support
from gluonts.mx.kernels import RBFKernel
from gluonts.model.gp_forecaster.gaussian_process import GaussianProcess
from gluonts.support.linalg_util import jitter_cholesky, jitter_cholesky_eig
Expand Down