Skip to content

Commit

Permalink
Merge pull request #2146 from ngoldbaum/all-data-speedup
Browse files Browse the repository at this point in the history
[yt-4.0] optimize accessing data through ds.all_data() for SPH
  • Loading branch information
Nathan Goldbaum authored Feb 22, 2019
2 parents d6df5dd + 0548506 commit 14ec989
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 31 deletions.
44 changes: 25 additions & 19 deletions yt/frontends/gadget/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,17 @@ def _read_particle_fields(self, chunks, ptf, selector):
if data_file.total_particles[ptype] == 0:
continue
g = f["/%s" % ptype]
coords = g["Coordinates"][si:ei].astype("float64")
if ptype == 'PartType0':
hsmls = g["SmoothingLength"][si:ei].astype("float64")
if getattr(selector, 'is_all_data', False):
mask = slice(None, None, None)
else:
hsmls = 0.0
mask = selector.select_points(
coords[:,0], coords[:,1], coords[:,2], hsmls)
del coords
coords = g["Coordinates"][si:ei].astype("float64")
if ptype == 'PartType0':
hsmls = g["SmoothingLength"][si:ei].astype("float64")
else:
hsmls = 0.0
mask = selector.select_points(
coords[:,0], coords[:,1], coords[:,2], hsmls)
del coords
if mask is None:
continue
for field in field_list:
Expand Down Expand Up @@ -309,19 +312,22 @@ def _read_particle_fields(self, chunks, ptf, selector):
tp = data_file.total_particles
f = open(data_file.filename, "rb")
for ptype, field_list in sorted(ptf.items()):
f.seek(poff[ptype, "Coordinates"], os.SEEK_SET)
pos = self._read_field_from_file(
f, tp[ptype], "Coordinates")
if ptype == self.ds._sph_ptype:
f.seek(poff[ptype, "SmoothingLength"], os.SEEK_SET)
hsml = self._read_field_from_file(
f, tp[ptype], "SmoothingLength")
if getattr(selector, 'is_all_data', False):
mask = slice(None, None, None)
else:
hsml = 0.0
mask = selector.select_points(
pos[:, 0], pos[:, 1], pos[:, 2], hsml)
del pos
del hsml
f.seek(poff[ptype, "Coordinates"], os.SEEK_SET)
pos = self._read_field_from_file(
f, tp[ptype], "Coordinates")
if ptype == self.ds._sph_ptype:
f.seek(poff[ptype, "SmoothingLength"], os.SEEK_SET)
hsml = self._read_field_from_file(
f, tp[ptype], "SmoothingLength")
else:
hsml = 0.0
mask = selector.select_points(
pos[:, 0], pos[:, 1], pos[:, 2], hsml)
del pos
del hsml
if mask is None:
continue
for field in field_list:
Expand Down
15 changes: 13 additions & 2 deletions yt/frontends/sph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ class IOHandlerSPH(BaseIOHandler):
"""

def _count_particles_chunks(self, psize, chunks, ptf, selector):
    """Add the number of selected particles of each type to ``psize``.

    Parameters
    ----------
    psize : mapping
        Running per-particle-type totals, indexed by particle type and
        updated in place with ``+=``.
    chunks : iterable
        Chunks to count over. On the all-data path each chunk must expose
        ``objs``, and each of those a ``data_files`` iterable.
    ptf : mapping
        Particle types to count; only its keys are used here.
    selector : object
        Selection object. If it has a truthy ``is_all_data`` attribute the
        per-file totals are used; otherwise ``selector.count_points`` is
        called for every batch of coordinates.

    Returns
    -------
    dict
        A plain ``dict`` copy of ``psize`` after the update.
    """
    if getattr(selector, 'is_all_data', False):
        # Every particle is selected, so the totals stored on each data
        # file are enough and no coordinates need to be read.
        # The same data file can be reachable from more than one chunk
        # object; collect them in a set so each file is counted once.
        data_files = set()
        for chunk in chunks:
            for obj in chunk.objs:
                data_files.update(obj.data_files)
        # The totals are summed, so the visiting order does not matter.
        for data_file in data_files:
            for ptype in ptf:
                psize[ptype] += data_file.total_particles[ptype]
    else:
        # Partial selection: count the points batch by batch. This loop
        # must run only on this branch, otherwise particles are counted
        # twice.
        for ptype, (x, y, z), hsml in self._read_particle_coords(chunks, ptf):
            psize[ptype] += selector.count_points(x, y, z, hsml)
    return dict(psize)
18 changes: 11 additions & 7 deletions yt/frontends/tipsy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,18 @@ def _read_particle_fields(self, chunks, ptf, selector):
auxdata.append(aux)
if afields:
p = append_fields(p, afields, auxdata)
x = p["Coordinates"]['x'].astype("float64")
y = p["Coordinates"]['y'].astype("float64")
z = p["Coordinates"]['z'].astype("float64")
if ptype == 'Gas':
hsml = self._read_smoothing_length(data_file, count)
if getattr(selector, 'is_all_data', False):
mask = slice(None, None, None)
else:
hsml = 0.
mask = selector.select_points(x, y, z, hsml)
x = p["Coordinates"]['x'].astype("float64")
y = p["Coordinates"]['y'].astype("float64")
z = p["Coordinates"]['z'].astype("float64")
if ptype == 'Gas':
hsml = self._read_smoothing_length(data_file, count)
else:
hsml = 0.
mask = selector.select_points(x, y, z, hsml)
del x, y, z, hsml
if mask is None:
continue
tf = self._fill_fields(field_list, p, hsml, mask, data_file)
Expand Down
10 changes: 7 additions & 3 deletions yt/geometry/particle_geometry_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ def _identify_base_chunk(self, dobj):
dobj._chunk_info = [dobj]
else:
# TODO: only return files
dfi, file_masks, addfi = self.regions.identify_file_masks(
dobj.selector)
nfiles = len(file_masks)
if getattr(dobj.selector, 'is_all_data', False):
dfi, file_masks, addfi = self.regions.identify_file_masks(
dobj.selector)
nfiles = len(file_masks)
else:
nfiles = self.regions.nfiles
dfi = np.arange(nfiles)
dobj._chunk_info = [None for _ in range(nfiles)]
for i in range(nfiles):
domain_id = i+1
Expand Down
7 changes: 7 additions & 0 deletions yt/geometry/selection_routines.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ cdef class RegionSelector(SelectorObject):
cdef np.float64_t left_edge[3]
cdef np.float64_t right_edge[3]
cdef np.float64_t right_edge_shift[3]
cdef public bint is_all_data
cdef bint loose_selection
cdef bint check_period[3]

Expand All @@ -958,6 +959,12 @@ cdef class RegionSelector(SelectorObject):
cdef np.float64_t[:] DW = _ensure_code(dobj.ds.domain_width)
cdef np.float64_t[:] DLE = _ensure_code(dobj.ds.domain_left_edge)
cdef np.float64_t[:] DRE = _ensure_code(dobj.ds.domain_right_edge)
le_all = (np.array(LE) == dobj.ds.domain_left_edge).all()
re_all = (np.array(RE) == dobj.ds.domain_right_edge).all()
if le_all and re_all:
self.is_all_data = True
else:
self.is_all_data = False
cdef np.float64_t region_width[3]
cdef bint p[3]
# This is for if we want to include zones that overlap and whose
Expand Down

0 comments on commit 14ec989

Please sign in to comment.