-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix constructor arguments in
ParallelReactionModel
(#96)
One of the more rarely used capabilities of our codebase is running multi-GPU inference during single-step evaluation. Testing this part of code requires having a GPU, and so `ParallelReactionModel` was not covered by unit tests, and it broke at some point during refactoring of the model classes. This PR adds simple tests for it and fixes the issue.
- Loading branch information
Showing
6 changed files
with
53 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
syntheseus/tests/reaction_prediction/utils/test_parallel.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
|
||
from syntheseus import Molecule | ||
from syntheseus.tests.cli.test_eval_single_step import DummyModel | ||
|
||
try: | ||
import torch | ||
|
||
from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel | ||
|
||
torch_available = True | ||
cuda_available = torch.cuda.is_available() | ||
except ModuleNotFoundError: | ||
torch_available = False | ||
cuda_available = False | ||
|
||
|
||
@pytest.mark.skipif( | ||
not torch_available, reason="Simple testing of parallel inference requires torch" | ||
) | ||
def test_parallel_reaction_model_cpu() -> None: | ||
# We cannot really run this on CPU, so just check if the model creation works as normal. | ||
parallel_model: ParallelReactionModel = ParallelReactionModel( | ||
model_fn=DummyModel, devices=["cpu"] * 4 | ||
) | ||
assert parallel_model([]) == [] | ||
|
||
|
||
@pytest.mark.skipif( | ||
not cuda_available, reason="Full testing of parallel inference requires GPU to be available" | ||
) | ||
def test_parallel_reaction_model_gpu() -> None: | ||
model = DummyModel() | ||
parallel_model: ParallelReactionModel = ParallelReactionModel( | ||
model_fn=DummyModel, devices=["cuda:0"] * 4 | ||
) | ||
|
||
inputs = [Molecule("C" * length) for length in range(1, 6)] | ||
assert parallel_model(inputs) == model(inputs) |