Commit 5276282
tidy code based on code review
ajecc committed Feb 1, 2024
1 parent 814452c commit 5276282
Showing 1 changed file with 16 additions and 27 deletions.
src/transformers/integrations/integration_utils.py
@@ -1502,10 +1502,7 @@ def setup(self, args, state, model, tokenizer, **kwargs):
             self._clearml_task = self._clearml.Task.init(
                 project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
                 task_name=os.getenv("CLEARML_TASK", "Trainer"),
-                auto_connect_frameworks={
-                    "tensorboard": False,
-                    "pytorch": False,
-                },
+                auto_connect_frameworks={"tensorboard": False, "pytorch": False},
                 output_uri=True,
             )
             self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
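Collapsing the dict is cosmetic, but the setting itself matters: `auto_connect_frameworks` stops ClearML from auto-hooking TensorBoard and PyTorch, so the callback reports scalars and checkpoints exactly once. A minimal standalone sketch of the same init (the env-var fallbacks mirror the hunk; `output_uri=True` uploads outputs to the default files server):

import os

from clearml import Task

# Sketch: disable automatic TensorBoard/PyTorch capture so nothing is logged twice.
task = Task.init(
    project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
    task_name=os.getenv("CLEARML_TASK", "Trainer"),
    auto_connect_frameworks={"tensorboard": False, "pytorch": False},
    output_uri=True,
)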
@@ -1515,38 +1512,29 @@ def setup(self, args, state, model, tokenizer, **kwargs):
             logger.info("ClearML Task has been initialized.")
             self._initialized = True
 
-            ignore_hparams_config_section = (
-                ClearMLCallback._hparams_section
-                + ClearMLCallback.log_suffix
-                + "/"
-                + ClearMLCallback._ignore_hparams_overrides
-            )
+            suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+            ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
             if self._clearml.Task.running_locally():
-                self._copy_training_args_as_hparams(
-                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
-                )
+                self._copy_training_args_as_hparams(args, suffixed_hparams_section)
                 self._clearml_task.set_parameter(
                     name=ignore_hparams_config_section,
                     value=True,
                     value_type=bool,
                     description=(
-                        "If True, ignore hyperparameters overrides done in the UI section"
-                        + "when running remotely. Otherwise, the overrides will be used"
+                        "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
+                        + "when running remotely. Otherwise, the overrides will be applied when running remotely"
                     ),
                 )
             elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
-                self._clearml_task.connect(args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix)
+                self._clearml_task.connect(args, suffixed_hparams_section)
             else:
                 self._copy_training_args_as_hparams(
                     args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
                 )
 
             if getattr(model, "config", None) is not None:
                 ignore_model_config_section = (
-                    ClearMLCallback._hparams_section
-                    + ClearMLCallback.log_suffix
-                    + "/"
-                    + ClearMLCallback._ignoge_model_config_overrides
+                    suffixed_hparams_section + "/" + ClearMLCallback._ignoge_model_config_overrides
                 )
                 configuration_object_description = ClearMLCallback._model_config_description.format(
                     ClearMLCallback._model_connect_counter
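The hunk above is the heart of the commit: the repeated `ClearMLCallback._hparams_section + ClearMLCallback.log_suffix` expression becomes `suffixed_hparams_section`, and the flag derived from it gates whether a remote rerun keeps edits made in the UI/backend. A minimal sketch of that gate pattern outside the Trainer (the section and flag names here are illustrative, not from the commit):

from clearml import Task

task = Task.init(project_name="Demo", task_name="override-gate")  # illustrative names

if Task.running_locally():
    # Local run: record the flag so a remote clone of this task can read it back.
    task.set_parameter(
        name="Args/ignore_overrides",  # illustrative section/flag
        value=True,
        value_type=bool,
        description="If True, ignore UI/backend overrides when running remotely",
    )
elif not task.get_parameter("Args/ignore_overrides", default=True, cast=True):
    # Remote run with overrides allowed: values edited in the UI flow back in.
    config = {"learning_rate": 5e-5}
    task.connect(config, "Args")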
@@ -1559,8 +1547,8 @@ def setup(self, args, state, model, tokenizer, **kwargs):
                         value=True,
                         value_type=bool,
                         description=(
-                            "If True, ignore model configuration overrides done in the UI section "
-                            + "when running remotely. Otherwise, the overrides will be used"
+                            "If True, ignore Transformers model configuration overrides done in the UI/backend "
+                            + "when running remotely. Otherwise, the overrides will be applied when running remotely"
                         ),
                     )
                 self._clearml_task.set_configuration_object(
@@ -1652,8 +1640,8 @@ def on_save(self, args, state, control, **kwargs):
         if self._log_model and self._clearml_task and state.is_world_process_zero:
             ckpt_dir = f"checkpoint-{state.global_step}"
             artifact_path = os.path.join(args.output_dir, ckpt_dir)
-            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
             name = ckpt_dir + ClearMLCallback.log_suffix
+            logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
             output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
             output_model.connect(task=self._clearml_task, name=name)
             output_model.update_weights_package(
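Computing `name` before the log line lets the message show the suffixed artifact name rather than the bare checkpoint directory. For context, roughly how a checkpoint directory becomes a task's output model in clearml (the task, path, and step here are made up):

from clearml import OutputModel, Task

task = Task.init(project_name="Demo", task_name="ckpt-upload")  # illustrative
name = "checkpoint-500"  # illustrative; the callback appends its log_suffix

output_model = OutputModel(task=task, name=name)
output_model.connect(task=task, name=name)
# Packages the directory and uploads it as this model's weights archive.
output_model.update_weights_package(
    weights_path="./output/checkpoint-500",  # illustrative path
    target_filename=name,
    iteration=500,
    auto_delete_file=False,
)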
@@ -1681,10 +1669,11 @@ def on_save(self, args, state, control, **kwargs):
                     self._checkpoints_saved = self._checkpoints_saved[1:]
 
     def _copy_training_args_as_hparams(self, training_args, prefix):
-        as_dict = {field.name: getattr(training_args, field.name) for field in fields(training_args) if field.init}
-        token_keys = [k for k in as_dict.keys() if k.endswith("_token")]
-        for token_key in token_keys:
-            as_dict.pop(token_key, None)
+        as_dict = {
+            field.name: getattr(training_args, field.name)
+            for field in fields(training_args)
+            if field.init and not field.name.endswith("_token")
+        }
         flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
         self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
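The rewritten comprehension folds the separate collect-and-pop loop for `*_token` keys into the filter clause, so secrets never enter the dict at all. A self-contained illustration with a toy dataclass standing in for `TrainingArguments`:

from dataclasses import dataclass, fields


@dataclass
class ToyArgs:  # stand-in for TrainingArguments
    learning_rate: float = 5e-5
    num_train_epochs: int = 3
    hub_token: str = "hf_secret"  # must not be logged as a hyperparameter


args = ToyArgs()
as_dict = {
    field.name: getattr(args, field.name)
    for field in fields(args)
    if field.init and not field.name.endswith("_token")
}
print(as_dict)  # {'learning_rate': 5e-05, 'num_train_epochs': 3}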
