Skip to content

Commit

Permalink
pt: make get_data non-blocking (#3422)
Browse files Browse the repository at this point in the history
`to(DEVICE)` is cpu-blocking but `to(DEVICE, non-blocking=True)` is not
blocking. This improves performance by at least 0.1s/100 steps.

Before, `get_data` is blocking:


![1709698811097](https://github.com/deepmodeling/deepmd-kit/assets/9496702/b86b3928-41e7-46d3-8692-ca96b3a6475a)


![1709698811150](https://github.com/deepmodeling/deepmd-kit/assets/9496702/c4365203-3f3d-4de8-aae6-d8587f0e95a0)

After, `get_data` is not blocking:

![1709698811122](https://github.com/deepmodeling/deepmd-kit/assets/9496702/d991c8f0-35c8-4b5d-822e-77af961e9b6e)

![1709698811169](https://github.com/deepmodeling/deepmd-kit/assets/9496702/a56160c2-78c7-4a44-aa96-1df0b520a60a)

The subsequent blocking is `phys2inter` (via `torch.linalg.inv`).

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Mar 8, 2024
1 parent 09bd522 commit 268591c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,9 +973,11 @@ def get_data(self, is_train=True, task_key="Default"):
continue
elif not isinstance(batch_data[key], list):
if batch_data[key] is not None:
batch_data[key] = batch_data[key].to(DEVICE)
batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True)
else:
batch_data[key] = [item.to(DEVICE) for item in batch_data[key]]
batch_data[key] = [
item.to(DEVICE, non_blocking=True) for item in batch_data[key]
]
# we may need a better way to classify which are inputs and which are labels
# now wrapper only supports the following inputs:
input_keys = [
Expand Down

0 comments on commit 268591c

Please sign in to comment.