Plug and Play fit_batch() for the estimator class #16930
Comments
A common difficulty in using callbacks is confusion about the signature. You can work around it by defining a class, as is done in the event handlers.
Why don't we just pack the inputs and unpack them when we feed them into the model? The order of the inputs would be determined by the user, i.e., by how they construct the dataloader. The RNN model-specific input ...
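Roughly, the pack/unpack idea could look like the sketch below (only a sketch; the tuple layout and all names are illustrative, not an existing estimator API):

```python
from mxnet import autograd

# Sketch of the pack/unpack idea: the dataloader packs all model inputs
# (in whatever order the user chose) followed by the label, and fit_batch()
# simply splats the inputs into the network. Names are illustrative only.
def fit_batch(self, train_batch, batch_axis=0):
    *model_inputs, label = train_batch      # e.g. (data,) or (data, data_length)
    with autograd.record():
        pred = self.net(*model_inputs)      # works for net(x) as well as net(x, x_len)
        loss = self.loss(pred, label)
    loss.backward()
    return model_inputs, label, pred, loss
```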
Ok, we will wrap it in the class.
The same considerations should apply to ...
I think so.
The problem is that even if we use the pack/unpack workaround to bypass the input/output issue, the loss computation may still differ between models: for one model the loss is computed one way, while for another it follows a different routine. In this case we may have to resort to if/else clauses to decide which computation routine to use, and the code will become cumbersome in the future.
Description
In the current estimator implementation, fit_batch() is a method of the estimator class. The common workflow of fit_batch() is that the model self.net forwards the training batch to generate outputs and compute the loss. The problem is that such a design is not flexible enough to accommodate different model forward interfaces on the same task. For example, fit_batch() of the base estimator trains the current batch on a label prediction task:
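Roughly, the default fit_batch() has the following shape (a paraphrase rather than the exact library code; how the batch is split across devices and whether the loss is a single object or a list may differ):

```python
from mxnet import autograd
from mxnet.gluon.utils import split_and_load

# Rough shape of the base estimator's fit_batch(): split the batch across
# devices, forward through self.net with a single input, compute the loss,
# and run backward.
def fit_batch(self, train_batch, batch_axis=0):
    data = split_and_load(train_batch[0], ctx_list=self.context, batch_axis=batch_axis)
    label = split_and_load(train_batch[1], ctx_list=self.context, batch_axis=batch_axis)
    with autograd.record():
        pred = [self.net(x) for x in data]                      # assumes net(inputs) only
        loss = [self.loss(p, y) for p, y in zip(pred, label)]
    for l in loss:
        l.backward()
    return data, label, pred, loss
```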
In the above example, the model forward interface of self.net is def forward(self, inputs), and the return value is the predicted labels. The estimator is compatible with any model using this forward interface. However, if we have another model for the label prediction task with a different forward interface, say def forward(self, inputs, input_length), the base estimator is not compatible with that model even though both models share the same loss functions, training and evaluation metrics. A real-world example can be found in the gluon-nlp LM models (https://github.com/dmlc/gluon-nlp/blob/c03665bafb1e0fe0fa5c2a59bbb4845393fbf9ba/src/gluonnlp/model/train/language_model.py): StandardRNN and AWDRNN share the same forward interface, whereas BigRNN has a different one.
Forward interface of StandardRNN and AWDRNN:
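Roughly, both models expose something like the stub below (only the signature matters here; see the linked file for the real implementation):

```python
from mxnet.gluon import Block

class StandardRNNLike(Block):
    """Illustrative stand-in for StandardRNN/AWDRNN: the forward pass takes
    the input sequence plus an optional recurrent state and returns the
    decoded output together with the new state."""
    def forward(self, inputs, begin_state=None):
        # The real models embed, run the RNN, and decode; this stub only
        # mirrors the signature.
        return inputs, begin_state
```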
Forward interface of BigRNN:
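Roughly, an illustrative stub of the different interface (again, see the linked file for the real signature):

```python
from mxnet.gluon import Block

class BigRNNLike(Block):
    """Illustrative stand-in for BigRNN: its training-time forward pass needs
    extra arguments (e.g. the labels used by the sampled softmax and the
    previous state), so it cannot be called as net(inputs) like the models
    above."""
    def forward(self, inputs, label, begin_state):
        # The real model runs an LSTM with importance-sampled softmax; this
        # stub only mirrors the signature.
        return inputs, begin_state, label
```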
A straightforward workaround is to create a new customized estimator for each model forward interface. This brings the issue that we need to create a standalone estimator each time we see a new model interface, even on the same task. In the machine learning community it is common to see different model forward logic on the same task, so this approach would lead to prohibitively many estimators for some simple tasks. In the above LM example, we would need to create a vanillaRNNEstimator and a BigRNNEstimator even though most of the training logic between these two estimators is the same.
To prevent the above estimator explosion issue, we suggest adding support for a plug-and-play customized fit_batch(), similar to the event_handlers of the estimator class. Given an existing estimator est, we modify the fit() method to take an extra argument fit_batch_handler. Then we can call est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_StandardRNN_batch) or est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_BigRNN_batch) to use models with different interfaces. If no fit_batch_handler is provided, the default fit_batch() method is used.
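To make the proposal concrete, a handler could simply be a free function with the same contract as the default fit_batch(). The sketch below is only one possible shape of the proposed API; the handler names are taken from the example above, and everything else is illustrative:

```python
from mxnet import autograd

# One possible shape of the proposed plug-and-play handlers: both reuse the
# estimator's loss and metrics, only the forward call differs.
def fit_StandardRNN_batch(estimator, train_batch, batch_axis=0):
    # A model whose forward interface is net(inputs).
    inputs, label = train_batch
    with autograd.record():
        pred = estimator.net(inputs)
        loss = estimator.loss(pred, label)
    loss.backward()
    return inputs, label, pred, loss

def fit_BigRNN_batch(estimator, train_batch, batch_axis=0):
    # A model whose forward needs extra arguments (shown generically here as
    # *extra_inputs, e.g. sequence lengths or labels for sampled softmax).
    inputs, *extra_inputs, label = train_batch
    with autograd.record():
        pred = estimator.net(inputs, *extra_inputs)
        loss = estimator.loss(pred, label)
    loss.backward()
    return inputs, label, pred, loss

# est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_StandardRNN_batch)
# est.fit(train_data=data_loader, epochs=epochs, fit_batch_handler=fit_BigRNN_batch)
```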