diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index a65a325306..316e3a1893 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -114,9 +114,12 @@ def download_model(self, model_name): # download from gdrive self._download_gdrive_file(model_item['model_file'], output_model_path) self._download_gdrive_file(model_item['config_file'], output_config_path) - if self._check_dict_key(model_item, 'stats_file'): + if self._check_dict_key(model_item, 'stats_file'): + self._download_gdrive_file(model_item['stats_file'], output_stats_path) + + # set the scale_path.npy file path in the model config.json + if self._check_dict_key(model_item, 'stats_file') or os.path.exists(os.path.join(output_path, 'scale_stats.npy')): output_stats_path = os.path.join(output_path, 'scale_stats.npy') - self._download_gdrive_file(model_item['stats_file'], output_stats_path) # set scale stats path in config.json config_path = output_config_path config = load_config(config_path)