Commit 814452c
run ruff format
ajecc committed Feb 1, 2024
1 parent d08d2da commit 814452c
Showing 1 changed file with 28 additions and 62 deletions.
src/transformers/integrations/integration_utils.py (90 changes: 28 additions & 62 deletions)
@@ -1462,9 +1462,7 @@ def __init__(self):
 
             self._clearml = clearml
         else:
-            raise RuntimeError(
-                "ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`."
-            )
+            raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")
 
         self._initialized = False
         self._clearml_task = None
@@ -1480,9 +1478,7 @@ def setup(self, args, state, model, tokenizer, **kwargs):
         ClearMLCallback._train_run_counter += 1
         ClearMLCallback._model_connect_counter += 1
         ClearMLCallback.log_suffix = (
-            ""
-            if ClearMLCallback._train_run_counter == 1
-            else "_" + str(ClearMLCallback._train_run_counter)
+            "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
         )
         if state.is_world_process_zero:
             logger.info("Automatic ClearML logging enabled.")
@@ -1495,33 +1491,26 @@ def setup(self, args, state, model, tokenizer, **kwargs):
 
                 # This might happen when running inside of a pipeline, where the task is already initialized
                 # from outside of Hugging Face
-                if (
-                    self._clearml.Task.running_locally()
-                    and self._clearml.Task.current_task()
-                ):
+                if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
                     self._clearml_task = self._clearml.Task.current_task()
                     self._log_model = os.getenv(
                         "CLEARML_LOG_MODEL",
-                        "FALSE"
-                        if not ClearMLCallback._task_created_in_callback
-                        else "TRUE",
+                        "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
                     ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
                     logger.info("External ClearML Task has been connected.")
                 else:
                     self._clearml_task = self._clearml.Task.init(
-                        project_name=os.getenv(
-                            "CLEARML_PROJECT", "HuggingFace Transformers"
-                        ),
+                        project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
                         task_name=os.getenv("CLEARML_TASK", "Trainer"),
                         auto_connect_frameworks={
                             "tensorboard": False,
                             "pytorch": False,
                         },
                         output_uri=True,
                     )
-                    self._log_model = os.getenv(
-                        "CLEARML_LOG_MODEL", "TRUE"
-                    ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+                    self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
+                        {"TRUE"}
+                    )
                     ClearMLCallback._task_created_in_callback = True
                     logger.info("ClearML Task has been initialized.")
                 self._initialized = True
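Aside on the `CLEARML_LOG_MODEL` gate reflowed above, a minimal sketch assuming `ENV_VARS_TRUE_VALUES` is the transformers constant `{"1", "ON", "YES", "TRUE"}` (which makes the extra `.union({"TRUE"})` redundant but harmless):

import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to match transformers.utils

def should_log_model(default: str = "TRUE") -> bool:
    # Case-insensitive check: "1", "on", "yes" and "true" all enable checkpoint upload.
    return os.getenv("CLEARML_LOG_MODEL", default).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

Note how the default differs by branch in the diff: "FALSE" when attaching to an external task the callback did not create, "TRUE" when the callback created the task itself.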
@@ -1534,29 +1523,22 @@ def setup(self, args, state, model, tokenizer, **kwargs):
             )
             if self._clearml.Task.running_locally():
                 self._copy_training_args_as_hparams(
-                    args,
-                    ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
                 )
                 self._clearml_task.set_parameter(
                     name=ignore_hparams_config_section,
                     value=True,
                     value_type=bool,
                     description=(
-                        "If True, ignore hyperparameters overrides done in the UI section" +
-                        "when running remotely. Otherwise, the overrides will be used"
-                    )
-                )
-            elif not self._clearml_task.get_parameter(
-                ignore_hparams_config_section, default=True, cast=True
-            ):
-                self._clearml_task.connect(
-                    args,
-                    ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
-                )
+                        "If True, ignore hyperparameters overrides done in the UI section"
+                        + "when running remotely. Otherwise, the overrides will be used"
+                    ),
+                )
+            elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
+                self._clearml_task.connect(args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix)
             else:
                 self._copy_training_args_as_hparams(
-                    args,
-                    ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
                 )
 
             if getattr(model, "config", None) is not None:
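The reflow above preserves a three-way decision that is easy to misread in the collapsed form; a hypothetical stand-in with placeholder strings in place of the real ClearML calls:

def sync_hparams(running_locally: bool, ignore_ui_overrides: bool) -> str:
    if running_locally:
        # Local run: snapshot the TrainingArguments and record the ignore flag.
        return "copy args as hparams; set the ignore-overrides parameter to True"
    elif not ignore_ui_overrides:
        # Remote rerun with overrides enabled: UI edits flow back into args.
        return "task.connect(args) so UI overrides are applied"
    else:
        # Remote rerun with overrides ignored: just snapshot args again.
        return "copy args as hparams; UI edits are not applied"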
@@ -1567,7 +1549,7 @@ def setup(self, args, state, model, tokenizer, **kwargs):
                     + ClearMLCallback._ignoge_model_config_overrides
                 )
                 configuration_object_description = ClearMLCallback._model_config_description.format(
-                   ClearMLCallback._model_connect_counter
+                    ClearMLCallback._model_connect_counter
                 )
                 if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
                     configuration_object_description += " " + ClearMLCallback._model_config_description_note
@@ -1577,18 +1559,16 @@ def setup(self, args, state, model, tokenizer, **kwargs):
                         value=True,
                         value_type=bool,
                         description=(
-                            "If True, ignore model configuration overrides done in the UI section " +
-                            "when running remotely. Otherwise, the overrides will be used"
-                        )
+                            "If True, ignore model configuration overrides done in the UI section "
+                            + "when running remotely. Otherwise, the overrides will be used"
+                        ),
                     )
                     self._clearml_task.set_configuration_object(
                         name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
                         config_dict=model.config.to_dict(),
-                        description=configuration_object_description
+                        description=configuration_object_description,
                     )
-                elif not self._clearml_task.get_parameter(
-                    ignore_model_config_section, default=True, cast=True
-                ):
+                elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
                     model.config = model.config.from_dict(
                         self._clearml_task.get_configuration_object_as_dict(
                             ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
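For orientation, the remote branch entered by this `elif` reads the stored configuration object back and rebuilds the model config from it. A condensed sketch using the same ClearML and transformers calls; `config_section` is illustrative and stands in for `ClearMLCallback._model_config_section + ClearMLCallback.log_suffix`:

# Sketch: `task` is an initialized clearml.Task, `model` a transformers PreTrainedModel.
config_section = "Model Configuration"  # illustrative name, not asserted from the source
remote_config = task.get_configuration_object_as_dict(config_section)
model.config = model.config.from_dict(remote_config)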
@@ -1598,12 +1578,10 @@ def setup(self, args, state, model, tokenizer, **kwargs):
                     self._clearml_task.set_configuration_object(
                         name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
                         config_dict=model.config.to_dict(),
-                        description=configuration_object_description
+                        description=configuration_object_description,
                     )
 
-    def on_train_begin(
-        self, args, state, control, model=None, tokenizer=None, **kwargs
-    ):
+    def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._clearml is None:
             return
         self._checkpoints_saved = []
@@ -1617,9 +1595,7 @@ def on_train_end(self, args, state, control, **kwargs):
             self._clearml_task.close()
             ClearMLCallback._train_run_counter = 0
 
-    def on_log(
-        self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs
-    ):
+    def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
         if self._clearml is None:
             return
         if not self._initialized:
@@ -1676,9 +1652,7 @@ def on_save(self, args, state, control, **kwargs):
         if self._log_model and self._clearml_task and state.is_world_process_zero:
             ckpt_dir = f"checkpoint-{state.global_step}"
             artifact_path = os.path.join(args.output_dir, ckpt_dir)
-            logger.info(
-                f"Logging checkpoint artifacts in {ckpt_dir}. This may take time."
-            )
+            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
             name = ckpt_dir + ClearMLCallback.log_suffix
             output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
             output_model.connect(task=self._clearml_task, name=name)
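The surrounding upload pattern, condensed into a standalone sketch (assumes an initialized ClearML task and a checkpoint folder on disk; the path and step number are made up):

from clearml import OutputModel, Task

task = Task.current_task()
ckpt_name = "checkpoint-500"
output_model = OutputModel(task=task, name=ckpt_name)
output_model.connect(task=task, name=ckpt_name)
output_model.update_weights_package(
    weights_path="outputs/checkpoint-500",  # args.output_dir + ckpt_dir in the callback
    target_filename=ckpt_name,
    iteration=500,
    auto_delete_file=False,  # keep the local files after the upload
)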
@@ -1689,9 +1663,7 @@ def on_save(self, args, state, control, **kwargs):
                 auto_delete_file=False,
             )
             self._checkpoints_saved.append(output_model)
-            while args.save_total_limit and args.save_total_limit < len(
-                self._checkpoints_saved
-            ):
+            while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
                 try:
                     self._clearml.model.Model.remove(
                         self._checkpoints_saved[0],
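The reflowed `while` condition implements FIFO checkpoint rotation; a pure-Python sketch of the invariant it maintains (hypothetical helper, plain strings in place of OutputModel objects):

def prune(checkpoints: list, save_total_limit) -> list:
    # Drop the oldest entries until at most `save_total_limit` remain;
    # a falsy limit (None or 0) disables pruning entirely.
    while save_total_limit and save_total_limit < len(checkpoints):
        checkpoints.pop(0)
    return checkpoints

assert prune(["ckpt-100", "ckpt-200", "ckpt-300"], 2) == ["ckpt-200", "ckpt-300"]
assert prune(["ckpt-100"], None) == ["ckpt-100"]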
@@ -1702,8 +1674,7 @@ def on_save(self, args, state, control, **kwargs):
                 except Exception as e:
                     logger.warning(
                         "Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
-                            self._checkpoints_saved[0].name,
-                            e
+                            self._checkpoints_saved[0].name, e
                         )
                     )
                     break
@@ -1714,12 +1685,7 @@ def _copy_training_args_as_hparams(self, training_args, prefix):
         token_keys = [k for k in as_dict.keys() if k.endswith("_token")]
         for token_key in token_keys:
             as_dict.pop(token_key, None)
-        flat_dict = {
-            str(k): v
-            for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(
-                as_dict
-            ).items()
-        }
+        flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
         self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
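On the comprehension collapsed in this last hunk: `flatten_dictionary` turns the nested TrainingArguments dict into flat keys before they are copied as ClearML hyperparameters, after any `*_token` keys have been popped. A plain-Python approximation, under the assumption that nested keys are joined with "/":

def flatten(d: dict, prefix: str = "") -> dict:
    # Approximates clearml.utilities.proxy_object.flatten_dictionary (assumed "/" join).
    flat = {}
    for k, v in d.items():
        key = f"{prefix}/{k}" if prefix else str(k)
        if isinstance(v, dict):
            flat.update(flatten(v, key))
        else:
            flat[key] = v
    return flat

as_dict = {"learning_rate": 5e-5, "hub_token": "hf_secret", "lr_scheduler_kwargs": {"gamma": 0.9}}
as_dict.pop("hub_token", None)  # token keys are scrubbed first, as in the method above
assert flatten(as_dict) == {"learning_rate": 5e-5, "lr_scheduler_kwargs/gamma": 0.9}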

