Skip to content

Commit

Permalink
Merge pull request #4661 from cphyc/rockstar-members
Browse files Browse the repository at this point in the history
[ENH] Read Rockstar members
  • Loading branch information
matthewturk authored Oct 4, 2024
2 parents ded48a3 + af60658 commit 1509b29
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 2 deletions.
120 changes: 118 additions & 2 deletions yt/frontends/rockstar/data_structures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import glob
import os
from functools import cached_property
from typing import Any, Optional

import numpy as np

Expand All @@ -9,22 +11,58 @@
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.utilities import fortran_utils as fpu
from yt.utilities.cosmology import Cosmology
from yt.utilities.exceptions import YTFieldNotFound

from .definitions import header_dt
from .fields import RockstarFieldInfo


class RockstarBinaryFile(HaloCatalogFile):
header: dict
_position_offset: int
_member_offset: int
_Npart: "np.ndarray[Any, np.dtype[np.int64]]"
_ids_halos: list[int]
_file_size: int

def __init__(self, ds, io, filename, file_id, range):
with open(filename, "rb") as f:
self.header = fpu.read_cattrs(f, header_dt, "=")
self._position_offset = f.tell()
pcount = self.header["num_halos"]

halos = np.fromfile(f, dtype=io._halo_dt, count=pcount)
self._member_offset = f.tell()
self._ids_halos = list(halos["particle_identifier"])
self._Npart = halos["num_p"]

f.seek(0, os.SEEK_END)
self._file_size = f.tell()

expected_end = self._member_offset + 8 * self._Npart.sum()
if expected_end != self._file_size:
raise RuntimeError(
f"File size {self._file_size} does not match expected size {expected_end}."
)

super().__init__(ds, io, filename, file_id, range)

def _read_particle_positions(self, ptype, f=None):
def _read_member(
self, ihalo: int
) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
if ihalo not in self._ids_halos:
return None

ind_halo = self._ids_halos.index(ihalo)

ipos = self._member_offset + 8 * self._Npart[:ind_halo].sum()

with open(self.filename, "rb") as f:
f.seek(ipos, os.SEEK_SET)
ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo])
return ids

def _read_particle_positions(self, ptype: str, f=None):
"""
Read all particle positions in this file.
"""
Expand All @@ -48,8 +86,18 @@ def _read_particle_positions(self, ptype, f=None):
return pos


class RockstarIndex(ParticleIndex):
def get_member(self, ihalo: int):
for df in self.data_files:
members = df._read_member(ihalo)
if members is not None:
return members

raise RuntimeError(f"Could not find halo {ihalo} in any data file.")


class RockstarDataset(ParticleDataset):
_index_class = ParticleIndex
_index_class = RockstarIndex
_file_class = RockstarBinaryFile
_field_info_class = RockstarFieldInfo
_suffix = ".bin"
Expand Down Expand Up @@ -122,3 +170,71 @@ def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
return False
else:
return header["magic"] == 18077126535843729616

def halo(self, ptype, particle_identifier):
return RockstarHaloContainer(
ptype,
particle_identifier,
parent_ds=None,
halo_ds=self,
)


class RockstarHaloContainer:
def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds):
if ptype not in halo_ds.particle_types_raw:
raise RuntimeError(
f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".'
)

self.ds = parent_ds
self.halo_ds = halo_ds
self.ptype = ptype
self.particle_identifier = particle_identifier

def __repr__(self):
return f"{self.halo_ds}_{self.ptype}_{self.particle_identifier:09d}"

def __getitem__(self, key):
if isinstance(key, tuple):
ptype, field = key
else:
ptype = self.ptype
field = key

data = {
"mass": self.mass,
"position": self.position,
"velocity": self.velocity,
"member_ids": self.member_ids,
}
if ptype == "halos" and field in data:
return data[field]

raise YTFieldNotFound((ptype, field), dataset=self.ds)

@cached_property
def ihalo(self):
halo_id = self.particle_identifier
halo_ids = list(self.halo_ds.r["halos", "particle_identifier"].astype("i8"))
ihalo = halo_ids.index(halo_id)

assert halo_ids[ihalo] == halo_id

return ihalo

@property
def mass(self):
return self.halo_ds.r["halos", "particle_mass"][self.ihalo]

@property
def position(self):
return self.halo_ds.r["halos", "particle_position"][self.ihalo]

@property
def velocity(self):
return self.halo_ds.r["halos", "particle_velocity"][self.ihalo]

@property
def member_ids(self):
return self.halo_ds.index.get_member(self.particle_identifier)
20 changes: 20 additions & 0 deletions yt/frontends/rockstar/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,23 @@ def test_particle_selection():
ds = data_dir_load(r1)
psc = ParticleSelectionComparison(ds)
psc.run_defaults()


@requires_file(r1)
def test_halo_loading():
ds = data_dir_load(r1)

for halo_id, Npart in zip(
ds.r["halos", "particle_identifier"],
ds.r["halos", "num_p"],
):
halo = ds.halo("halos", halo_id)
assert halo is not None

# Try accessing properties
halo.position
halo.velocity
halo.mass

# Make sure we can access the member particles
assert_equal(len(halo.member_ids), Npart)

0 comments on commit 1509b29

Please sign in to comment.