diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py index 6d6307c8..2b80ad8c 100644 --- a/pypots/cli/tuning.py +++ b/pypots/cli/tuning.py @@ -6,6 +6,7 @@ # License: BSD-3-Clause +import inspect import os from argparse import ArgumentParser, Namespace @@ -237,7 +238,7 @@ def run(self): lr = tuner_params.pop("lr") # check if hyperparameters match - model_all_arguments = model_class.__init__.__annotations__.keys() + model_all_arguments = inspect.signature(model_class).parameters.keys() tuner_params_set = set(tuner_params.keys()) model_arguments_set = set(model_all_arguments) if_hyperparameter_match = tuner_params_set.issubset(model_arguments_set)