Skip to content

Commit

Permalink
Merge pull request #630 from OceanParcels/fixing_particlefile_MPI_mode
Browse files Browse the repository at this point in the history
Some fixes to the ParticleFile for MPI mode
  • Loading branch information
erikvansebille authored Aug 22, 2019
2 parents 6ff69e8 + 3ae397b commit c01b4d2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
26 changes: 16 additions & 10 deletions parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,17 @@ def open_netcdf_file(self, data_shape):
setattr(self.dataset, name, message)

def __del__(self):
# The export can only start when all threads are done.
if MPI:
MPI.COMM_WORLD.Barrier()
rank = MPI.COMM_WORLD.Get_rank()
else:
rank = 0
if self.convert_at_end and rank == 0: # only export once.
if self.convert_at_end:
self.close()

def close(self, delete_tempfiles=True):
"""Close the ParticleFile object by exporting and then deleting
the temporary npy files"""
self.export()
if delete_tempfiles:
self.delete_tempwritedir(tempwritedir=self.tempwritedir_base)
self.dataset.close()
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
if mpi_rank == 0:
if delete_tempfiles:
self.delete_tempwritedir(tempwritedir=self.tempwritedir_base)
self.convert_at_end = False

def add_metadata(self, name, message):
Expand Down Expand Up @@ -336,6 +331,12 @@ def read_from_npy(self, file_list, time_steps, var):
def export(self):
"""Exports outputs in temporary NPY-files to NetCDF file"""

if MPI:
# The export can only start when all threads are done.
MPI.COMM_WORLD.Barrier()
if MPI.COMM_WORLD.Get_rank() > 0:
return # export only on threat 0

# Retrieve all temporary writing directories and sort them in numerical order
temp_names = sorted(glob(os.path.join("%s" % self.tempwritedir_base, "*")),
key=lambda x: int(os.path.basename(x)))
Expand All @@ -344,17 +345,20 @@ def export(self):
raise RuntimeError("No npy files found in %s" % self.tempwritedir_base)

global_maxid_written = -1
global_time_written = []
global_file_list = []
if len(self.var_names_once) > 0:
global_file_list_once = []
for tempwritedir in temp_names:
if os.path.exists(tempwritedir):
pset_info_local = np.load(os.path.join(tempwritedir, 'pset_info.npy'), allow_pickle=True).item()
global_maxid_written = np.max([global_maxid_written, pset_info_local['maxid_written']])
global_time_written += pset_info_local['time_written']
global_file_list += pset_info_local['file_list']
if len(self.var_names_once) > 0:
global_file_list_once += pset_info_local['file_list_once']
self.maxid_written = global_maxid_written
self.time_written = np.unique(global_time_written)

for var in self.var_names:
data = self.read_from_npy(global_file_list, len(self.time_written), var)
Expand All @@ -367,6 +371,8 @@ def export(self):
for var in self.var_names_once:
getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var)

self.dataset.close()

def delete_tempwritedir(self, tempwritedir=None):
"""Deleted all temporary npy files
:param tempwritedir Optional path of the directory to delete
Expand Down
1 change: 0 additions & 1 deletion tests/test_particle_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def close_and_compare_netcdffiles(filepath, ofile, assystemcall=False):

ofile.name = filepath + 'b.nc'
ofile.export()
ofile.dataset.close()
ncfile2 = Dataset(filepath + 'b.nc', 'r', 'NETCDF4')

for v in ncfile2.variables.keys():
Expand Down

0 comments on commit c01b4d2

Please sign in to comment.