diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py
old mode 100644
new mode 100755
index b520d322bd5e..460c72399c6b
--- a/deepspeed/pt/deepspeed_config.py
+++ b/deepspeed/pt/deepspeed_config.py
@@ -256,7 +256,8 @@ def _initialize_params(self, param_dict):
         self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
 
         self.optimizer_name = get_optimizer_name(param_dict)
-        if self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
+        if self.optimizer_name is not None and \
+            self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
             self.optimizer_name = self.optimizer_name.lower()
 
         self.optimizer_params = get_optimizer_params(param_dict)
diff --git a/tests/unit/test_ds_config.py b/tests/unit/test_ds_config.py
new file mode 100755
index 000000000000..8e49748710c2
--- /dev/null
+++ b/tests/unit/test_ds_config.py
@@ -0,0 +1,19 @@
+import pytest
+import os
+import json
+from deepspeed.pt import deepspeed_config as ds_config
+
+
+def test_only_required_fields(tmpdir):
+    '''Ensure that config containing only the required fields is accepted. '''
+    cfg_json = tmpdir.mkdir('ds_config_unit_test').join('minimal.json')
+
+    with open(cfg_json, 'w') as f:
+        required_fields = {'train_batch_size': 64}
+        json.dump(required_fields, f)
+
+    run_cfg = ds_config.DeepSpeedConfig(cfg_json)
+    assert run_cfg is not None
+    assert run_cfg.train_batch_size == 64
+    assert run_cfg.train_micro_batch_size_per_gpu == 64
+    assert run_cfg.gradient_accumulation_steps == 1
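
For context, a minimal sketch of the failure mode this guard addresses: when the DeepSpeed JSON config omits the `optimizer` block (as in the minimal config the new unit test writes), `get_optimizer_name` presumably returns `None`, and the previous unconditional `.lower()` call would raise an `AttributeError`. The helper and optimizer list below are illustrative stand-ins, not DeepSpeed's actual implementation.

```python
# Illustrative sketch (not DeepSpeed's actual code) of why the None guard matters.
DEEPSPEED_OPTIMIZERS = ['adam', 'lamb']  # hypothetical subset, for illustration only

minimal_cfg = {'train_batch_size': 64}  # no 'optimizer' block, as in the new unit test


def get_optimizer_name(param_dict):
    # Assumed behavior: return None when the config has no optimizer section.
    if 'optimizer' in param_dict:
        return param_dict['optimizer'].get('type')
    return None


optimizer_name = get_optimizer_name(minimal_cfg)

# Old behavior: optimizer_name.lower() raises AttributeError when optimizer_name is None.
# New behavior: the added None check skips normalization for configs without an optimizer.
if optimizer_name is not None and optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
    optimizer_name = optimizer_name.lower()

print(optimizer_name)  # -> None for the minimal config above
```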