Skip to content

Commit

Permalink
Now one can use getMPISize and getMPIRank before init().
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Apr 1, 2024
1 parent fc52355 commit ea43f32
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
21 changes: 9 additions & 12 deletions pyquda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class _ComputeCapability(NamedTuple):
minor: int


_MPI_COMM: MPI.Comm = None
_MPI_SIZE: int = 1
_MPI_RANK: int = 0
_GRID_SIZE: List[int] = [1, 1, 1, 1]
_MPI_COMM: MPI.Comm = MPI.COMM_WORLD
_MPI_SIZE: int = _MPI_COMM.Get_size()
_MPI_RANK: int = _MPI_COMM.Get_rank()
_GRID_SIZE: List[int] = None
_GRID_COORD: List[int] = [0, 0, 0, 0]
_DEFAULT_LATTICE: LatticeInfo = None
_CUDA_BACKEND: Literal["numpy", "cupy", "torch"] = "cupy"
Expand Down Expand Up @@ -123,18 +123,15 @@ def init(
Initialize MPI along with the QUDA library.
"""
filterwarnings("default", "", DeprecationWarning)
global _MPI_COMM, _MPI_SIZE, _MPI_RANK, _GRID_SIZE, _GRID_COORD
if _MPI_COMM is None:
global _GRID_SIZE, _GRID_COORD
if _GRID_SIZE is None:
import atexit
from platform import node as gethostname

Gx, Gy, Gz, Gt = grid_size if grid_size is not None else [1, 1, 1, 1]
_MPI_COMM = MPI.COMM_WORLD
_MPI_SIZE = _MPI_COMM.Get_size()
_MPI_RANK = _MPI_COMM.Get_rank()
assert _MPI_SIZE == Gx * Gy * Gz * Gt
_GRID_SIZE = [Gx, Gy, Gz, Gt]
_GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE)
assert _MPI_SIZE == Gx * Gy * Gz * Gt
printRoot(f"INFO: Using gird {_GRID_SIZE}")

_initEnvironWarn(resource_path=resource_path if resource_path != "" else None)
Expand Down Expand Up @@ -262,8 +259,8 @@ def getDefaultLattice():


def setGPUID(gpuid: int):
global _MPI_COMM, _GPUID
assert _MPI_COMM is None, "setGPUID() should be called before init()"
global _GPUID
assert _GRID_SIZE is None, "setGPUID() should be called before init()"
assert gpuid >= 0
_GPUID = gpuid

Expand Down
2 changes: 1 addition & 1 deletion pyquda/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.6.14"
__version__ = "0.6.15"

0 comments on commit ea43f32

Please sign in to comment.