Skip to content

Commit

Permalink
update check for mxnet
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed May 31, 2024
1 parent 5cee03c commit 7a67505
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:


def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
class_name = prediction_net.__class__.__module__
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)
try:
# Feed inputs as positional arguments for MXNet block predictors
import mxnet as mx

if isinstance(prediction_net, mx.gluon.HybridBlock):
return prediction_net(*inputs.values())
except ImportError:
pass
return prediction_net(**inputs)


class ForecastGenerator:
Expand Down

0 comments on commit 7a67505

Please sign in to comment.