From 7bbb2e6bf0c275583f1063706a7ef69418bae6ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:11:10 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/entrypoints/main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index a23ae0eda4..71f81b0a12 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -283,7 +283,9 @@ def train( # update init_model or init_frz_model config if necessary if (init_model is not None or init_frz_model is not None) and use_pretrain_script: if init_model is not None: - init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True) + init_state_dict = torch.load( + init_model, map_location=DEVICE, weights_only=True + ) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] @@ -380,7 +382,9 @@ def change_bias( output: Optional[str] = None, ): if input_file.endswith(".pt"): - old_state_dict = torch.load(input_file, map_location=env.DEVICE, weights_only=True) + old_state_dict = torch.load( + input_file, map_location=env.DEVICE, weights_only=True + ) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] elif input_file.endswith(".pth"):