Skip to content

Commit

Permalink
Fix constructor arguments in ParallelReactionModel (#96)
Browse files Browse the repository at this point in the history
One of the more rarely used capabilities of our codebase is running
multi-GPU inference during single-step evaluation. Testing this part of
code requires having a GPU, and so `ParallelReactionModel` was not
covered by unit tests, and it broke at some point during refactoring of
the model classes. This PR adds simple tests for it and fixes the issue.
  • Loading branch information
kmaziarz authored Aug 13, 2024
1 parent feb813d commit 38f475f
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ jobs:
run: |
coverage run -p -m pytest \
./syntheseus/tests/cli/test_cli.py \
./syntheseus/tests/reaction_prediction/inference/test_models.py
./syntheseus/tests/reaction_prediction/inference/test_models.py \
./syntheseus/tests/reaction_prediction/utils/test_parallel.py
coverage report --data-file .coverage.*
- name: Upload coverage report
uses: actions/upload-artifact@v4
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
### Fixed

- Shift the `pandas` dependency to the external model packages ([#94](https://github.com/microsoft/syntheseus/pull/94)) ([@kmaziarz])
- Fix constructor arguments in `ParallelReactionModel` ([#96](https://github.com/microsoft/syntheseus/pull/96)) ([@kmaziarz])

## [0.4.1] - 2024-05-04

Expand Down
4 changes: 3 additions & 1 deletion syntheseus/reaction_prediction/utils/model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ def model_fn(device):
except ModuleNotFoundError:
raise ValueError("Multi-GPU evaluation is only supported for torch-based models")

return ParallelReactionModel(model_fn, devices=[f"cuda:{idx}" for idx in range(num_gpus)])
return ParallelReactionModel(
model_fn=model_fn, devices=[f"cuda:{idx}" for idx in range(num_gpus)]
)
4 changes: 3 additions & 1 deletion syntheseus/reaction_prediction/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class ParallelReactionModel(ReactionModel[InputType, ReactionType]):
appropriately), whereas other approaches usually only work with tensors.
"""

def __init__(self, model_fn: Callable, devices: List) -> None:
def __init__(self, *args, model_fn: Callable, devices: List, **kwargs) -> None:
super().__init__(*args, **kwargs)

self._devices = devices
self._model_replicas = [model_fn(device=device) for device in devices]

Expand Down
6 changes: 5 additions & 1 deletion syntheseus/tests/cli/test_eval_single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@


class DummyModel(ReactionModel):
def __init__(self, is_forward: bool, repeat: bool, **kwargs) -> None:
def __init__(
self, device: str = "cpu", is_forward: bool = False, repeat: bool = False, **kwargs
) -> None:
super().__init__(**kwargs)

self.device = device
self._is_forward = is_forward
self._repeat = repeat

Expand Down
39 changes: 39 additions & 0 deletions syntheseus/tests/reaction_prediction/utils/test_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from syntheseus import Molecule
from syntheseus.tests.cli.test_eval_single_step import DummyModel

try:
import torch

from syntheseus.reaction_prediction.utils.parallel import ParallelReactionModel

torch_available = True
cuda_available = torch.cuda.is_available()
except ModuleNotFoundError:
torch_available = False
cuda_available = False


@pytest.mark.skipif(
not torch_available, reason="Simple testing of parallel inference requires torch"
)
def test_parallel_reaction_model_cpu() -> None:
# We cannot really run this on CPU, so just check if the model creation works as normal.
parallel_model: ParallelReactionModel = ParallelReactionModel(
model_fn=DummyModel, devices=["cpu"] * 4
)
assert parallel_model([]) == []


@pytest.mark.skipif(
not cuda_available, reason="Full testing of parallel inference requires GPU to be available"
)
def test_parallel_reaction_model_gpu() -> None:
model = DummyModel()
parallel_model: ParallelReactionModel = ParallelReactionModel(
model_fn=DummyModel, devices=["cuda:0"] * 4
)

inputs = [Molecule("C" * length) for length in range(1, 6)]
assert parallel_model(inputs) == model(inputs)

0 comments on commit 38f475f

Please sign in to comment.