generated from ssciwr/python-project-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload.py
232 lines (191 loc) · 9.24 KB
/
load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import h5py
import numpy as np
def print_hdf5_file_structure(file_path):
    '''
    Print the group/dataset tree of an HDF5 file.

    The file is opened read-only, a ``File: <path>`` header line is printed,
    and the root group is then handed to ``print_hdf5_group_structure``.

    Parameters:
    -----------
    file_path : str
        The path of the HDF5 file to inspect.
    '''
    with h5py.File(file_path, 'r') as hdf5_file:
        # The header is printed only once the file has been opened.
        print(f"File: {file_path}")
        print_hdf5_group_structure(hdf5_file)
def print_hdf5_group_structure(group, indent=0):
    '''
    Recursively print the members of an HDF5 group.

    Groups are printed as ``Group: <name>`` and then descended into with the
    indentation increased by 4. Every other member is printed as
    ``Dataset: <name> (<dtype>) (<shape>)``.

    Parameters:
    -----------
    group : h5py.Group
        The group (or open file) whose members are listed.
    indent : int, optional
        The number of padding characters printed in front of each line.
    '''
    padding = " " * indent
    for name in group.keys():
        member = group[name]
        if not isinstance(member, h5py.Group):
            # Leaf node: report its dtype and shape.
            print(f"{padding}Dataset: {name} ({member.dtype}) ({member.shape})")
            continue
        print(f"{padding}Group: {name}")
        print_hdf5_group_structure(member, indent + 4)
# TODO: maybe add get_galaxy function to get all the data of a specific galaxy
class Gamma:
    '''Class to load the data generated by the generate_data.py script.

    Only the group and dataset names are read when the object is created.
    The data itself is read on demand; the HDF5 file is opened in read-only
    mode and closed again on every access.

    Parameters:
    -----------
    path : str
        The path to the hdf5 file generated by the generate_data.py script.
    show_structure : bool, optional
        If True, print the structure of the hdf5 file.
    m_min : float, optional
        The minimum mass of the galaxies to load, compared against log10 of
        the "mass" attribute. If None, no lower bound is applied.
    m_max : float, optional
        The maximum mass of the galaxies to load, compared against log10 of
        the "mass" attribute. If None, no upper bound is applied.

    Attributes:
    -----------
    path : str
        The path to the hdf5 file generated by the generate_data.py script.
    mask : np.ndarray or None
        The boolean mask used to filter the data. If None, no mask is applied.
    _galaxy_attributes_keys : list
        The keys of the galaxy attributes.
    _particles_keys : list
        The keys of the particles.
    _image_fields : dict
        The fields of the images, as ``{particle: {dimension: [fields]}}``.
    '''

    def __init__(self, path, show_structure=True, m_min=None, m_max=None):
        self.path = path
        # Always define the attribute so that "no mask" is an explicit None,
        # as documented, instead of a missing attribute.
        self.mask = None
        self._load_keys()

        if show_structure:
            self.show_structure()

        # If m_min and/or m_max are specified, filter the data
        if m_min is not None or m_max is not None:
            self.mask = self._create_mass_mask(m_min, m_max)

    def __getitem__(self, key):
        '''
        Read the complete dataset stored under ``key`` in the hdf5 file.

        The mask is not applied here.

        Parameters:
        -----------
        key : str
            The path of the dataset inside the hdf5 file.

        Returns:
        --------
        data : np.ndarray
            The full content of the dataset.
        '''
        with h5py.File(self.path, 'r') as f:
            return f[key][()]

    def _load_keys(self):
        '''
        Collect the names of the galaxy attributes, particle types and image
        fields from the hdf5 file without loading any data into memory.
        '''
        with h5py.File(self.path, 'r') as f:
            self._galaxy_attributes_keys = list(f["Galaxies/Attributes"].keys())

            particles = f["Galaxies/Particles"]
            self._particles_keys = list(particles.keys())

            # Load the fields of the images: {particle: {dimension: [fields]}}
            self._image_fields = dict()
            for particle in self._particles_keys:
                images = particles[particle]["Images"]
                self._image_fields[particle] = {
                    dim: list(images[dim].keys()) for dim in images.keys()
                }

    def _read_dataset(self, dataset, index=None, ignore_mask=False):
        '''
        Read from an open dataset, honouring ``index`` and the mask.

        Parameters:
        -----------
        dataset : h5py.Dataset
            The open dataset to read from.
        index : int, optional
            The galaxy index to read. If None, all (or all masked) galaxies
            are read.
        ignore_mask : bool, optional
            Whether to ignore the mask set for the dataset.

        Returns:
        --------
        data : np.ndarray
            The requested part of the dataset.
        '''
        if index is not None:
            return dataset[index]

        mask = None if ignore_mask else getattr(self, "mask", None)
        if mask is None:
            return dataset[()]

        # h5py only accepts a boolean mask whose shape equals the shape of the
        # whole dataset, so a per-galaxy mask cannot select the rows of a
        # multi-dimensional dataset (e.g. images) directly. Increasing row
        # indices along the first axis work for datasets of any rank.
        selected = np.flatnonzero(mask)
        if selected.size == 0:
            # An empty slice keeps the dtype and the trailing dimensions.
            return dataset[0:0]
        return dataset[selected]

    def get_attribute(self, attribute, index=None, ignore_mask=False):
        '''
        Get a galaxy attribute.

        If index is None, first check if a mask is set, and then return the specific attribute of all galaxies.
        Otherwise, return the attribute of the specified index in the dataset.

        Parameters:
        -----------
        attribute : str
            The attribute to return. Must be a valid attribute, otherwise a ValueError is raised.
        index : int, optional
            The galaxy index in the dataset to return. If None, return all galaxies.
        ignore_mask : bool, optional
            Whether to ignore the mask set for the dataset. Default is False.

        Returns:
        --------
        attribute : np.ndarray
            The attribute of the specified galaxy index in the dataset.

        Raises:
        -------
        ValueError
            If the attribute is not one of the galaxy attributes of the file.

        Examples:
        ---------
        >>> data = Gamma("data.hdf5")
        >>> data.get_attribute("mass") # Get the mass of all galaxies in the dataset
        >>> data.get_attribute("mass", 10) # Get the mass of the galaxy at index 10 in the dataset

        Load Data with a mask
        >>> data = Gamma("data.hdf5", m_min=10, m_max=11)
        >>> data.get_attribute("mass") # Get the mass of all galaxies in the dataset with mass between 10^10 and 10^11
        '''
        # Check if the attribute is valid
        if attribute not in self._galaxy_attributes_keys:
            raise ValueError(f"Attribute {attribute} not found. Valid attributes are: {self._galaxy_attributes_keys}")

        # Open the hdf5 file in read-only mode
        with h5py.File(self.path, 'r') as f:
            return self._read_dataset(f["Galaxies/Attributes"][attribute], index, ignore_mask)

    def get_image(self, particle_type, field, index=None, ignore_mask=False, dim=2):
        '''
        Get the image of the specified particle type and field.

        If index is None, return all images (restricted by the mask, if one is set).
        Otherwise, return the galaxy image of the specified index and field in the dataset.

        Parameters:
        -----------
        particle_type : str
            The particle type of the image.
        field : str
            The field of the image.
        index : int, optional
            The galaxy index in the dataset to return. If None, return all images.
        ignore_mask : bool, optional
            Whether to ignore the mask set for the dataset. Default is False.
        dim : int or str, optional
            The dimension of the image. If 2, return a 2D image. If 3, return a 3D image.
            A string is used as the name of the dimension group as is (e.g. "dim2").

        Returns:
        --------
        image : np.ndarray
            The image of the specified particle type and field.

        Raises:
        -------
        ValueError
            If the particle type, the dimension or the field is not available.

        Examples:
        ---------
        >>> data = Gamma("data.hdf5")
        >>> image = data.get_image("stars", "Masses", 10) # Get the stars masses image of the galaxy at index 10 in the dataset
        >>> all_images = data.get_image("stars", "Masses") # Get all stars masses images in the dataset
        '''
        # Check if the particle type is valid
        if particle_type not in self._particles_keys:
            raise ValueError(f"Particle type {particle_type} not found. Valid particle types are: {self._particles_keys}")

        # Translate the dimension into the name of its group. Anything other
        # than 2, 3 or a group name is rejected instead of silently being
        # treated as 3D.
        if isinstance(dim, str):
            dimension = dim
        elif dim in (2, 3):
            dimension = f"dim{int(dim)}"
        else:
            raise ValueError(f"Dimension {dim} not supported. Use 2, 3 or the name of a dimension group.")

        available = self._image_fields[particle_type]
        if dimension not in available:
            raise ValueError(f"Dimension {dimension} not found for particle type {particle_type}. Valid dimensions are: {list(available)}")

        # Check if the field is valid
        if field not in available[dimension]:
            raise ValueError(f"Field {field} not found. Valid fields are: {available}")

        # Open the hdf5 file in read-only mode
        with h5py.File(self.path, 'r') as f:
            dataset = f[f"Galaxies/Particles/{particle_type}/Images/{dimension}/{field}"]
            return self._read_dataset(dataset, index, ignore_mask)

    def show_structure(self):
        '''
        Print the structure of the HDF5 file.

        This method is useful to check the structure of the HDF5 file and the available attributes and images.
        Every dataset is listed with its dtype and its shape.

        Examples:
        ---------
        The names, dtypes and shapes below are only illustrative.

        >>> data = Gamma("data.hdf5")
        >>> data.show_structure()
        File: data.hdf5
        Group: Galaxies
            Group: Attributes
                Dataset: halo_id (int64) (100,)
                Dataset: mass (float64) (100,)
            Group: Particles
                Group: stars
                    Group: Images
                        Group: dim2
                            Dataset: Masses (float32) (100, 64, 64)
        '''
        print_hdf5_file_structure(self.path)

    def _create_mass_mask(self, m_min=None, m_max=None):
        '''
        Create a mask to filter the data based on the mass of the galaxies. Not to be used directly, use the get_attribute method instead.
        The parameters should be in log10(Msun) units.

        The mask always covers every galaxy of the file, even when another mask is already set.

        Parameters:
        -----------
        m_min : float, optional
            The minimum mass of the galaxies to keep. If None, no lower bound is applied.
        m_max : float, optional
            The maximum mass of the galaxies to keep. If None, no upper bound is applied.

        Returns:
        --------
        mask : np.ndarray
            The mask to apply to the data to filter the galaxies based on their mass.
        '''
        # Read the unfiltered masses: a mask that is already set must not
        # shorten the array the new mask is derived from.
        masses = np.log10(self.get_attribute("mass", ignore_mask=True))

        # A missing bound is skipped rather than replaced by the min/max of
        # the data, which fails on an empty dataset and turns into NaN (and
        # thus an all-False mask) as soon as one mass is NaN.
        mask = np.ones(np.shape(masses), dtype=bool)
        if m_min is not None:
            mask &= masses >= m_min
        if m_max is not None:
            mask &= masses <= m_max
        return mask