
Commit

refactor: please use get
YutackPark committed Nov 6, 2024
1 parent 6ad6afe commit e785944
Showing 3 changed files with 19 additions and 16 deletions.
sevenn/error_recorder.py (15 changes: 8 additions & 7 deletions)
@@ -331,16 +331,17 @@ def init_total_loss_metric(config, criteria):
 
     @staticmethod
     def from_config(config: dict):
-        loss_cls = loss_dict[config[KEY.LOSS].lower()]
-        try:
-            loss_param = config[KEY.LOSS_PARAM]
-        except KeyError:
-            loss_param = {}
+        loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
+        loss_param = config.get(KEY.LOSS_PARAM, {})
         criteria = loss_cls(**loss_param)
 
-        err_config = config[KEY.ERROR_RECORD]
+        err_config = config.get(KEY.ERROR_RECORD, False)
+        if not err_config:
+            raise ValueError(
+                'No error_record config found. Consider util.get_error_recorder'
+            )
         err_config_n = []
-        if not config[KEY.IS_TRAIN_STRESS]:
+        if not config.get(KEY.IS_TRAIN_STRESS, True):
             for err_type, metric_name in err_config:
                 if 'Stress' in err_type:
                     continue
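This hunk swaps a direct lookup wrapped in try/except KeyError for dict.get, which returns the supplied default when the key is absent. A minimal sketch of the same pattern, using hypothetical stand-ins for sevenn's loss_dict and KEY.* constants (the string keys and lambda below are illustrative, not the repository's actual values):

    # Hypothetical stand-ins for sevenn's loss_dict and KEY.* constants.
    loss_dict = {'mse': lambda **kwargs: ('MSELoss', kwargs)}
    config = {}  # 'loss' and 'loss_param' deliberately absent

    # Before: direct indexing needs an explicit try/except fallback.
    try:
        loss_param = config['loss_param']
    except KeyError:
        loss_param = {}

    # After: .get collapses the fallback into a single expression.
    loss_cls = loss_dict[config.get('loss', 'mse').lower()]
    loss_param = config.get('loss_param', {})
    criteria = loss_cls(**loss_param)  # -> ('MSELoss', {})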
sevenn/scripts/processing_epoch.py (4 changes: 2 additions & 2 deletions)
@@ -32,8 +32,8 @@ def processing_epoch_v2(
     prefix = f'{os.path.abspath(working_dir)}/'
 
     total_epoch = total_epoch or config[KEY.EPOCH]
-    per_epoch = per_epoch or config[KEY.PER_EPOCH]
-    best_metric = best_metric or config[KEY.BEST_METRIC]
+    per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
+    best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
     recorder = error_recorder or ErrorRecorder.from_config(config)
     recorders = {k: deepcopy(recorder) for k in loaders}
 
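In this hunk the explicit function argument still takes precedence; config.get only supplies a fallback when the argument is falsy and the config key is missing. A short sketch of that argument-or-config-or-default chain, with plain string keys as hypothetical stand-ins for KEY.PER_EPOCH and KEY.BEST_METRIC:

    # String keys below are hypothetical stand-ins for sevenn's KEY.* constants.
    config = {'per_epoch': 5}

    def resolve(per_epoch=None, best_metric=None):
        per_epoch = per_epoch or config.get('per_epoch', 10)
        best_metric = best_metric or config.get('best_metric', 'TotalLoss')
        return per_epoch, best_metric

    print(resolve())             # (5, 'TotalLoss'): config value, then .get default
    print(resolve(per_epoch=2))  # (2, 'TotalLoss'): the explicit argument wins
    # Caveat: `or` treats falsy arguments (0, '') as missing, so per_epoch=0
    # would silently fall through to the config value or the default.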
sevenn/train/trainer.py (16 changes: 9 additions & 7 deletions)
@@ -82,13 +82,15 @@ def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
         trainer = Trainer(
             model,
             loss_functions=get_loss_functions_from_config(config),
-            optimizer_cls=optim_dict[config[KEY.OPTIMIZER].lower()],
-            optimizer_args=config[KEY.OPTIM_PARAM],
-            scheduler_cls=scheduler_dict[config[KEY.SCHEDULER].lower()],
-            scheduler_args=config[KEY.SCHEDULER_PARAM],
-            device=config[KEY.DEVICE],
-            distributed=config[KEY.IS_DDP],
-            distributed_backend=config[KEY.DDP_BACKEND]
+            optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
+            optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
+            scheduler_cls=scheduler_dict[
+                config.get(KEY.SCHEDULER, 'exponentiallr').lower()
+            ],
+            scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
+            device=config.get(KEY.DEVICE, 'auto'),
+            distributed=config.get(KEY.IS_DDP, False),
+            distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
         )
         return trainer
 
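Two details in this hunk are easy to miss: a default that feeds a dispatch table must itself be a valid key ('adam' in optim_dict, 'exponentiallr' in scheduler_dict), and a boolean flag needs a boolean default, since any non-empty string, including 'False', is truthy. A sketch with a toy dispatch table (the dict below is hypothetical; sevenn's optim_dict actually maps names to torch optimizer classes):

    # Toy dispatch table; hypothetical stand-in for sevenn's optim_dict.
    optim_dict = {'adam': 'torch.optim.Adam', 'sgd': 'torch.optim.SGD'}
    config = {}

    # Safe only because the default 'adam' is itself a key of optim_dict.
    optimizer_cls = optim_dict[config.get('optimizer', 'adam').lower()]

    # Truthiness pitfall: any non-empty string is truthy.
    assert bool('False') is True
    distributed = config.get('is_ddp', False)  # boolean default keeps the flag off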
