diff --git a/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py b/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py index 7359beb3a7..5864d37931 100644 --- a/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py @@ -317,7 +317,7 @@ def from_inputs(cls, train_iter, **params): if field in params.keys(): is_params_field = ( params[field] - if type(params[field]) == bool + if isinstance(params[field], bool) else strtobool(params[field]) ) if is_params_field and not auto_params[field]: