
Plug and Play fit_batch() for the estimator class #16930

Closed
liuzh47 opened this issue Nov 28, 2019 · 8 comments

Comments

@liuzh47
Contributor

liuzh47 commented Nov 28, 2019

Description

In the current estimator implementation, fit_batch() is a method of the estimator class. A common workflow of fit_batch() is that the model self.net forwards the training batch to generate outputs, which are then used to compute the loss. The problem is that such a design is not flexible enough to accommodate different model forward interfaces for the same task. For example, fit_batch() of the base estimator trains the current batch on the label prediction task:

        with autograd.record():
            pred = [self.net(x) for x in data]
            loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]

In the above example, the model forward interface of self.net is def forward(self, inputs), returning the predicted labels. The estimator is compatible with any model using this forward interface. However, if we have another model for the label prediction task with a different forward interface, say def forward(self, inputs, input_length), the base estimator is not compatible with that model even though both models share the same loss functions and the same training and evaluation metrics. A real-world example can be found in the gluon-nlp language models (https://github.com/dmlc/gluon-nlp/blob/c03665bafb1e0fe0fa5c2a59bbb4845393fbf9ba/src/gluonnlp/model/train/language_model.py): StandardRNN and AWDRNN share the same forward interface, whereas BigRNN has a different one.

Forward interface of StandardRNN and AWDRNN:

def __call__(self, inputs, begin_state=None):

Forward interface of BigRNN:

def forward(self, inputs, label, begin_state, sampled_values):

A straightforward workaround is to create a new customized estimator for each model interface. This brings the issue that we need to create a standalone estimator every time we encounter a new model interface, even for the same task. In the machine learning community, it is common to see different model forward logic for the same task, so this approach would lead to prohibitively many estimators for even simple tasks. In the above LM example, we would need to create a vanillaRNNEstimator and a BigRNNEstimator even though most of the training logic between the two estimators is the same.

To prevent the above estimator explosion issue, we suggest adding support for a plug-and-play customized fit_batch(), similar to the event_handlers of the estimator class. Given an existing estimator est, we modify the fit() method to take an extra argument, fit_batch_handler. We can then call est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_StandardRNN_batch) or est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_BigRNN_batch) to use models with different interfaces.
If no fit_batch_handler is provided, the default fit_batch() method is used.
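
A minimal sketch of what such a handler could look like. The handler signature (estimator, train_batch, batch_axis) is an assumption that mirrors the existing fit_batch(train_batch, batch_axis) method, with the estimator passed in explicitly; the batch layout is illustrative only.

    from mxnet import autograd

    def fit_StandardRNN_batch(estimator, train_batch, batch_axis=0):
        data, label = train_batch
        with autograd.record():
            # StandardRNN/AWDRNN-style interface: __call__(inputs, begin_state=None)
            output, state = estimator.net(data)
            loss = estimator.loss(output, label)
        loss.backward()
        return data, label, output, loss

    # Proposed usage:
    # est.fit(train_data=data_loader, epochs=epochs,
    #         fit_batch_handler=fit_StandardRNN_batch)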

@leezu
Contributor

leezu commented Nov 28, 2019

CC @sxjscience @szha @roywei

@liuzh47
Contributor Author

liuzh47 commented Nov 28, 2019

@szhengac

@szha
Member

szha commented Nov 28, 2019

A common difficulty in using callbacks is confusion about the signature. You can work around it by defining a class, as is done in the event handlers.
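
For illustration, a class-based variant might look like the following; the BatchProcessor name and its methods are hypothetical, chosen only to mirror how the event handlers fix their signatures through a common base class.

    from mxnet import autograd

    class BatchProcessor:
        # Hypothetical base class that fixes the callback signature,
        # analogous to the event handler base classes.
        def fit_batch(self, estimator, train_batch, batch_axis=0):
            raise NotImplementedError

    class StandardRNNBatchProcessor(BatchProcessor):
        def fit_batch(self, estimator, train_batch, batch_axis=0):
            data, label = train_batch
            with autograd.record():
                output, state = estimator.net(data)
                loss = estimator.loss(output, label)
            loss.backward()
            return data, label, output, loss

    # Hypothetical keyword name, whatever fit() ends up accepting:
    # est.fit(train_data=data_loader, epochs=epochs,
    #         batch_processor=StandardRNNBatchProcessor())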

@szhengac
Contributor

Why don't we just pack the input and unpack it when we feed it into the model? The order of the inputs will be determined by the user, i.e., by how they construct the DataLoader. The RNN-model-specific input begin_state will always serve as a keyword argument.
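
A rough sketch of this idea, where the batch layout and the state handling are assumptions for illustration: the DataLoader yields the positional inputs in whatever order the user chose, plus the label, and fit_batch() splats them into the forward call, with begin_state always passed as a keyword.

    from mxnet import autograd

    def fit_batch(self, train_batch, batch_axis=0):
        # The user packs the batch as (*inputs, label) via the DataLoader;
        # the model-specific begin_state stays a keyword argument.
        # self._begin_state is a hypothetical attribute for illustration.
        *inputs, label = train_batch
        with autograd.record():
            output = self.net(*inputs, begin_state=self._begin_state)
            loss = self.loss(output, label)
        loss.backward()
        return inputs, label, output, loss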

@liuzh47
Contributor Author

liuzh47 commented Nov 28, 2019

A common difficulty in using callbacks is confusion about the signature. You can work around it by defining a class, as is done in the event handlers.

OK, we will wrap it in a class.

@leezu
Contributor

leezu commented Nov 28, 2019

The same considerations should apply to evaluate_batch?

@liuzh47
Contributor Author

liuzh47 commented Nov 28, 2019

The same considerations should apply to evaluate_batch?

I think so

@liuzh47
Contributor Author

liuzh47 commented Nov 28, 2019

Why don't we just pack the input and unpack it when we feed it into the model? The order of the inputs will be determined by the user, i.e., by how they construct the DataLoader. The RNN-model-specific input begin_state will always serve as a keyword argument.

The problem is that even if we use the pack/unpack workaround to bypass the input/output issue, the computation of the loss function may still differ across models. For example, for StandardRNN, one has:

      output, h, encoder_hs, dropped_encoder_hs = model(X, h)
      l = joint_loss(output, y, encoder_hs, dropped_encoder_hs)
      Ls.append(l / (len(context) * X.size))
      hiddens[j] = h

On the other hand, for BigRNN, one has:

      output, hidden, new_target = self._model(X, y, h, s)
      output = output.reshape((-3, -1))
      new_target = new_target.reshape((-1,))
      ls = self._loss(output, new_target) * m.reshape((-1,))
      ls = ls / args.batch_size

In this case, we may have to resort to if/else clauses to decide which computation routine to use, and the code will become cumbersome over time.
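
With per-model handlers, each model-specific forward and loss routine would live in its own handler instead of an if/else block inside a shared fit_batch(). A hedged sketch, reusing the hypothetical BatchProcessor base class from above; the batch layout is illustrative only:

    from mxnet import autograd

    class BigRNNBatchProcessor(BatchProcessor):  # hypothetical base class from the sketch above
        def fit_batch(self, estimator, train_batch, batch_axis=0):
            # Illustrative batch layout: inputs, label, state, sampled values, mask.
            X, y, h, s, m = train_batch
            with autograd.record():
                output, hidden, new_target = estimator.net(X, y, h, s)
                loss = estimator.loss(output.reshape((-3, -1)),
                                      new_target.reshape((-1,))) * m.reshape((-1,))
            loss.backward()
            return X, y, output, loss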
