diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index bf11caf..d067d79 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -77,15 +77,10 @@ def convert_data(val, type, length, width): val = job.extract_compute(*filtered_args) return convert_data(val=val, type=type, length=length, width=width) elif style == 1: # per atom property - val = job.numpy.extract_compute(*filtered_args) - val_gather = MPI.COMM_WORLD.gather(val, root=0) + val = _gather_data_from_all_processors( + data=job.numpy.extract_compute(*filtered_args) + ) if MPI.COMM_WORLD.rank == 0: - # val_gather.shape [number of cores, atoms on specific core] - # the number of atoms on specific cores can vary - val = [] - for vl in val_gather: - for v in vl: - val.append(v) length = job.get_natoms() return convert_data(val=val, type=type, length=length, width=width) else: # Todo @@ -165,15 +160,20 @@ def extract_fix(funct_args): def extract_variable(funct_args): # in the args - if the third one, # which is the type is 1 - a lammps array is returned - if MPI.COMM_WORLD.rank == 0: - # if type is 1 - reformat file - try: - data = job.extract_variable(*funct_args) - except ValueError: - return [] - if funct_args[2] == 1: - data = np.array(data) - return data + if funct_args[2] == 1: + data = _gather_data_from_all_processors( + data=job.numpy.extract_variable(*funct_args) + ) + if MPI.COMM_WORLD.rank == 0: + return np.array(data) + else: + if MPI.COMM_WORLD.rank == 0: + # if type is 1 - reformat file + try: + data = job.extract_variable(*funct_args) + except ValueError: + return [] + return data def get_natoms(funct_args): @@ -472,6 +472,16 @@ def select_cmd(argument): return switcher.get(argument) +def _gather_data_from_all_processors(data): + data_gather = MPI.COMM_WORLD.gather(data, root=0) + if MPI.COMM_WORLD.rank == 0: + data = [] + for vl in data_gather: + for v in vl: + data.append(v) + return data + + if __name__ == "__main__": while True: if MPI.COMM_WORLD.rank == 0: diff --git a/tests/test_pylammpsmpi_cluster.py b/tests/test_pylammpsmpi_cluster.py index d1a7c01..356fe6f 100644 --- a/tests/test_pylammpsmpi_cluster.py +++ b/tests/test_pylammpsmpi_cluster.py @@ -59,7 +59,7 @@ def test_extract_variable(self): self.assertEqual(np.round(x, 2), 1.13) x = self.lmp.extract_variable("fx", "all", 1) - self.assertEqual(len(x), 128) + self.assertEqual(len(x), 256) self.assertEqual(np.round(x[0], 2), -0.26) def test_scatter_atoms(self): diff --git a/tests/test_pylammpsmpi_local.py b/tests/test_pylammpsmpi_local.py index e63ad8f..f2a831f 100644 --- a/tests/test_pylammpsmpi_local.py +++ b/tests/test_pylammpsmpi_local.py @@ -55,7 +55,7 @@ def test_extract_variable(self): x = self.lmp.extract_variable("tt", "all", 0) self.assertEqual(np.round(x, 2), 1.13) x = self.lmp.extract_variable("fx", "all", 1) - self.assertEqual(len(x), 128) + self.assertEqual(len(x), 256) self.assertEqual(np.round(x[0], 2), -0.26) def test_scatter_atoms(self):