Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore(pt): fix warning in test_training #4245

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,23 @@
f"The profiling trace have been saved to: {self.profiling_file}"
)

def delete_dataloader(self):
if self.multi_task:
for model_key in self.model_keys:
del (

Check warning on line 1043 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L1042-L1043

Added lines #L1042 - L1043 were not covered by tests
self.training_data[model_key],
self.training_dataloader[model_key],
self.validation_data[model_key],
self.validation_dataloader[model_key],
)
else:
del (
self.training_data,
self.training_dataloader,
self.validation_data,
self.validation_dataloader,
)
Comment on lines +1050 to +1055

Check warning

Code scanning / CodeQL

Unnecessary delete statement in function Warning

Unnecessary deletion of local variable
Tuple
in function
delete_dataloader
.

def save_model(self, save_path, lr=0.0, step=0):
module = (
self.wrapper.module
Expand Down
20 changes: 17 additions & 3 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
Expand Down Expand Up @@ -34,6 +35,7 @@ def test_dp_train(self):
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
trainer.run()
trainer.delete_dataloader()
state_dict_trained = trainer.wrapper.model.state_dict()

# test fine-tuning using same input
Expand Down Expand Up @@ -100,6 +102,11 @@ def test_dp_train(self):
trainer_finetune_empty.run()
trainer_finetune_random.run()

# delete dataloader to stop buffer fetching
trainer_finetune.delete_dataloader()
trainer_finetune_empty.delete_dataloader()
trainer_finetune_random.delete_dataloader()

def test_trainable(self):
fix_params = deepcopy(self.config)
fix_params["model"]["descriptor"]["trainable"] = False
Expand Down Expand Up @@ -195,18 +202,25 @@ def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.original_data_path = Path(__file__).parent / "water/data/data_0"
# Create a temporary directory for this test
self.temp_dir = Path(tempfile.mkdtemp())
self.temp_data_path = self.temp_dir / "data_0"
shutil.copytree(self.original_data_path, self.temp_data_path)

data_file = [str(self.temp_data_path)]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_se_e2_a)
self.config["model"]["fitting_net"]["numb_fparam"] = 1
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
self.set_path = self.temp_data_path / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)
# Remove the temporary directory and all its contents
shutil.rmtree(self.temp_dir)
DPTrainTest.tearDown(self)


Expand Down
Loading