Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 23, 2024
1 parent 43da6fc commit 8508e2f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
7 changes: 2 additions & 5 deletions src/fairchem/core/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,29 @@ class PGConfig:
use_gp: bool = True


def _init_env_rank_and_launch_test(
def init_env_rank_and_launch_test(
rank: int,
pg_setup_params: PGConfig,
mp_output_dict: dict[int, object],
test_method: callable,
args: list[object],
kwargs: dict[str, object],
init_process_group: bool = False,
) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = pg_setup_params.port
os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["RANK"] = str(rank)

mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme


def _init_pg_and_rank_and_launch_test(
def init_pg_and_rank_and_launch_test(
rank: int,
pg_setup_params: PGConfig,
mp_output_dict: dict[int, object],
test_method: callable,
args: list[object],
kwargs: dict[str, object],
init_process_group: bool = False,
) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = pg_setup_params.port
Expand Down
6 changes: 3 additions & 3 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import yaml
from fairchem.core.common.test_utils import (
PGConfig,
_init_env_rank_and_launch_test,
init_env_rank_and_launch_test,
spawn_multi_process,
)
from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes
Expand Down Expand Up @@ -121,12 +121,12 @@ def _run_main(

if world_size > 0:
pg_config = PGConfig(
backend="gloo", world_size=2, gp_group_size=1, use_gp=False
backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False
)
spawn_multi_process(
pg_config,
Runner(distributed=True),
_init_env_rank_and_launch_test,
init_env_rank_and_launch_test,
config,
)
else:
Expand Down
14 changes: 11 additions & 3 deletions tests/core/models/test_equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from torch.nn.parallel.distributed import DistributedDataParallel

from fairchem.core.common.registry import registry
from fairchem.core.common.test_utils import PGConfig, spawn_multi_process
from fairchem.core.common.test_utils import (
PGConfig,
init_pg_and_rank_and_launch_test,
spawn_multi_process,
)
from fairchem.core.common.utils import load_state_dict, setup_imports
from fairchem.core.datasets import data_list_collater
from fairchem.core.models.equiformer_v2.so3 import (
Expand Down Expand Up @@ -140,7 +144,9 @@ def test_energy_force_shape(self, snapshot):
def test_ddp(self, snapshot):
data_dist = self.data.clone().detach()
config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False)
output = spawn_multi_process(config, _runner, data_dist)
output = spawn_multi_process(
config, _runner, init_pg_and_rank_and_launch_test, data_dist
)
assert len(output) == 1
energy, forces = output[0]["energy"], output[0]["forces"]
assert snapshot == energy.shape
Expand All @@ -151,7 +157,9 @@ def test_ddp(self, snapshot):
def test_gp(self, snapshot):
data_dist = self.data.clone().detach()
config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True)
output = spawn_multi_process(config, _runner, data_dist)
output = spawn_multi_process(
config, _runner, init_pg_and_rank_and_launch_test, data_dist
)
assert len(output) == 2
energy, forces = output[0]["energy"], output[0]["forces"]
assert snapshot == energy.shape
Expand Down

0 comments on commit 8508e2f

Please sign in to comment.