Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport Gluon estimator changes to 1.6 #17048

Merged
merged 4 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/mxnet/gluon/contrib/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@
"""Gluon Estimator Module"""
from . import estimator
from . import event_handler
from . import batch_processor
from .estimator import *
from .event_handler import *
from .batch_processor import *
105 changes: 105 additions & 0 deletions python/mxnet/gluon/contrib/estimator/batch_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon Batch Processor for Estimators"""

from ...utils import split_and_load
from .... import autograd

__all__ = ['BatchProcessor']

class BatchProcessor(object):
"""BatchProcessor Class for plug and play fit_batch & evaluate_batch

During training or validation, data are divided into minibatches for processing. This
class aims at providing hooks of training or validating on a minibatch of data. Users
may provide customized fit_batch() and evaluate_batch() methods by inheriting from
this class and overriding class methods.

:py:class:`BatchProcessor` can be used to replace fit_batch() and evaluate_batch()
in the base estimator class
"""

def __init__(self):
pass

def _get_data_and_label(self, batch, ctx, batch_axis=0):
data = batch[0]
label = batch[1]
data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
return data, label

def evaluate_batch(self, estimator,
val_batch,
batch_axis=0):
"""Evaluate the estimator model on a batch of validation data.

Parameters
----------
estimator : Estimator
Reference to the estimator
val_batch : tuple
Data and label of a batch from the validation data loader.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
pred = [estimator.eval_net(x) for x in data]
loss = [estimator.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]

return data, label, pred, loss

def fit_batch(self, estimator,
train_batch,
batch_axis=0):
"""Trains the estimator model on a batch of training data.

Parameters
----------
estimator : Estimator
Reference to the estimator
train_batch : tuple
Data and label of a batch from the training data loader.
batch_axis : int, default 0
Batch axis to split the training data into devices.

Returns
-------
data: List of NDArray
Sharded data from the batch. Data is sharded with
`gluon.split_and_load`.
label: List of NDArray
Sharded label from the batch. Labels are sharded with
`gluon.split_and_load`.
pred: List of NDArray
Prediction on each of the sharded inputs.
loss: List of NDArray
Loss on each of the sharded inputs.
"""
data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)

with autograd.record():
pred = [estimator.net(x) for x in data]
loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]

for l in loss:
l.backward()

return data, label, pred, loss
Loading