Skip to content

Commit

Permalink
Fix typo in velocity loading; update plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
julballa committed Jun 3, 2024
1 parent ea299fb commit 137a31f
Show file tree
Hide file tree
Showing 3 changed files with 443 additions and 10 deletions.
10 changes: 5 additions & 5 deletions benchmarks/galaxies/dataset_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


# Mean and std for halos and cosmological parameters
MEAN_HALOS_DICT = {'x': 499.91877075908684, 'y': 500.0947802559321, 'z': 499.964508664328,'Jx': 212560050888254.06,'Jy': 349712732356652.25, 'Jz': -100259775332585.12, 'vx': -0.0512854365234889, 'vy': -0.01263126442198149, 'vz': -0.06458034372345466, 'M200c': 321308383763206.9, 'Rvir': 1424.4071655758826}
STD_HALOS_DICT = {'x': 288.71092533309235, 'y': 288.7525818573022, 'z': 288.70234893905575, 'Jx': 2.4294356933448945e+18, 'Jy': 2.3490019110577966e+18, 'Jz': 2.406422979830857e+18, 'vx': 344.0231468131901, 'vy': 343.9333673335964, 'vz': 344.071876710777, 'M200c': 405180433634974.75, 'Rvir': 298.14502916425675}
MEAN_HALOS_DICT = {'x': 499.91877075908684, 'y': 500.0947802559321, 'z': 499.964508664328,'Jx': 212560050888254.06,'Jy': 349712732356652.25, 'Jz': -100259775332585.12, 'v_x': -0.0512854365234889, 'v_y': -0.01263126442198149, 'v_z': -0.06458034372345466, 'M200c': 321308383763206.9, 'Rvir': 1424.4071655758826}
STD_HALOS_DICT = {'x': 288.71092533309235, 'y': 288.7525818573022, 'z': 288.70234893905575, 'Jx': 2.4294356933448945e+18, 'Jy': 2.3490019110577966e+18, 'Jz': 2.406422979830857e+18, 'v_x': 344.0231468131901, 'v_y': 343.9333673335964, 'v_z': 344.071876710777, 'M200c': 405180433634974.75, 'Rvir': 298.14502916425675}
MEAN_PARAMS_DICT = {'Omega_m': 0.29994175, 'Omega_b': 0.049990308, 'h': 0.69996387, 'n_s': 0.9999161, 'sigma_8': 0.7999111}
STD_PARAMS_DICT = {'Omega_m': 0.11547888, 'Omega_b': 0.017312417, 'h': 0.11543678, 'n_s': 0.115482554, 'sigma_8': 0.11545073}
MEAN_TPCF_VEC = [1.47385902e+01, 4.52754450e+00, 1.89688166e+00, 1.00795493e+00,
Expand All @@ -22,7 +22,7 @@
0.16773013, 0.15968612, 0.15186733, 0.14234885, 0.13153203, 0.11954234,
0.10549666, 0.09024256, 0.07655078, 0.06350282, 0.05210615, 0.0426435]

def _parse_function(proto, features=['x', 'y', 'z', 'Jx', 'Jy', 'Jz', 'vx', 'vy', 'vz', 'M200c'],
def _parse_function(proto, features=['x', 'y', 'z', 'Jx', 'Jy', 'Jz', 'v_x', 'v_y', 'v_z', 'M200c'],
params=['Omega_m', 'Omega_b', 'h', 'n_s', 'sigma_8'],
include_tpcf=False):

Expand Down Expand Up @@ -56,12 +56,12 @@ def _parse_function(proto, features=['x', 'y', 'z', 'Jx', 'Jy', 'Jz', 'vx', 'vy'
def get_halo_dataset(batch_size=64,
num_samples=None, # If not None, only return this many samples
split='train',
features=['x', 'y', 'z', 'Jx', 'Jy', 'Jz', 'vx', 'vy', 'vz', 'M200c', 'Rvir'],
features=['x', 'y', 'z', 'Jx', 'Jy', 'Jz', 'v_x', 'v_y', 'v_z', 'M200c', 'Rvir'],
params=['Omega_m', 'sigma_8'],
return_mean_std=False,
standardize=True,
seed=42,
tfrecords_path= '/quijote_tfrecords_consistent_trees',
tfrecords_path= '/home/jballa/galaxies/quijote_tfrecords_consistent_trees',
include_tpcf=False
):

Expand Down
433 changes: 433 additions & 0 deletions benchmarks/galaxies/plotting.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions benchmarks/galaxies/train_cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def __call__(self, x):
st_graph = get_equivariant_graph(
node_features=nodes,
positions=positions,
velocities=None,
steerable_velocities=False,
velocities=velocities,
steerable_velocities=True,
senders=x.senders,
receivers=x.receivers,
n_node=x.n_node,
Expand Down Expand Up @@ -337,7 +337,7 @@ def run_expt(
if feats == 'pos':
features = ['x', 'y', 'z']
elif feats == 'all':
features = ['x', 'y', 'z', 'vx', 'vy', 'vz']
features = ['x', 'y', 'z', 'v_x', 'v_y', 'v_z']
else:
raise NotImplementedError

Expand Down Expand Up @@ -381,9 +381,9 @@ def run_expt(
print('Train-Val-Test split:', n_train, n_val, n_test)

if use_tpcf == "small":
tpcf_idx = list(range(8))
tpcf_idx = list(range(6))
elif use_tpcf == "large":
tpcf_idx = list(range(15, 24))
tpcf_idx = list(range(13, 24))
else:
tpcf_idx = list(range(24))

Expand Down

0 comments on commit 137a31f

Please sign in to comment.