Skip to content

Commit

Permalink
Merge pull request #95 from pyiron/extract_variables
Browse files Browse the repository at this point in the history
extract_vairable()
  • Loading branch information
jan-janssen authored Feb 22, 2023
2 parents 170b183 + 6a339c1 commit 467d2f2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 19 deletions.
44 changes: 27 additions & 17 deletions pylammpsmpi/mpi/lmpmpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,10 @@ def convert_data(val, type, length, width):
val = job.extract_compute(*filtered_args)
return convert_data(val=val, type=type, length=length, width=width)
elif style == 1: # per atom property
val = job.numpy.extract_compute(*filtered_args)
val_gather = MPI.COMM_WORLD.gather(val, root=0)
val = _gather_data_from_all_processors(
data=job.numpy.extract_compute(*filtered_args)
)
if MPI.COMM_WORLD.rank == 0:
# val_gather.shape [number of cores, atoms on specific core]
# the number of atoms on specific cores can vary
val = []
for vl in val_gather:
for v in vl:
val.append(v)
length = job.get_natoms()
return convert_data(val=val, type=type, length=length, width=width)
else: # Todo
Expand Down Expand Up @@ -165,15 +160,20 @@ def extract_fix(funct_args):
def extract_variable(funct_args):
# in the args - if the third one,
# which is the type is 1 - a lammps array is returned
if MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
if funct_args[2] == 1:
data = np.array(data)
return data
if funct_args[2] == 1:
data = _gather_data_from_all_processors(
data=job.numpy.extract_variable(*funct_args)
)
if MPI.COMM_WORLD.rank == 0:
return np.array(data)
else:
if MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
return data


def get_natoms(funct_args):
Expand Down Expand Up @@ -472,6 +472,16 @@ def select_cmd(argument):
return switcher.get(argument)


def _gather_data_from_all_processors(data):
data_gather = MPI.COMM_WORLD.gather(data, root=0)
if MPI.COMM_WORLD.rank == 0:
data = []
for vl in data_gather:
for v in vl:
data.append(v)
return data


if __name__ == "__main__":
while True:
if MPI.COMM_WORLD.rank == 0:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pylammpsmpi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_extract_variable(self):
self.assertEqual(np.round(x, 2), 1.13)

x = self.lmp.extract_variable("fx", "all", 1)
self.assertEqual(len(x), 128)
self.assertEqual(len(x), 256)
self.assertEqual(np.round(x[0], 2), -0.26)

def test_scatter_atoms(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pylammpsmpi_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_extract_variable(self):
x = self.lmp.extract_variable("tt", "all", 0)
self.assertEqual(np.round(x, 2), 1.13)
x = self.lmp.extract_variable("fx", "all", 1)
self.assertEqual(len(x), 128)
self.assertEqual(len(x), 256)
self.assertEqual(np.round(x[0], 2), -0.26)

def test_scatter_atoms(self):
Expand Down

0 comments on commit 467d2f2

Please sign in to comment.