#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import warnings
import numpy as np
import shutil

import copy
from typing import Union
from tqdm import tqdm  # TODO remove for final push

from lips.dataset.dataSet import DataSet
from lips.logger import CustomLogger

class RollingWheelDataSet(DataSet):
    """
    This specific DataSet uses the Getfem framework to simulate data arising from a rolling wheel problem.
    """

    def __init__(self,
                 name="train",
                 attr_names=("disp",),
                 log_path: Union[str, None] = None
                 ):
        DataSet.__init__(self, name=name)
        self._attr_names = copy.deepcopy(attr_names)
        self.size = 0

        # logger
        self.logger = CustomLogger(__class__.__name__, log_path).logger

    def generate(self,
                 simulator: "GetfemSimulator",
                 actor,
                 path_out,
                 nb_samples,
                 simulator_seed: Union[None, int] = None,
                 actor_seed: Union[None, int] = None):
        """
        For this dataset, we use a GetfemSimulator and a Sampler to generate data from a rolling wheel.

        Parameters
        ----------
        simulator:
            In this case, this should be a GetfemSimulator instance

        actor:
            In this case, it is the sampler used to discretize the input parameter space

        path_out:
            The path where the data will be saved

        nb_samples:
            Number of rows (examples) in the final dataset

        simulator_seed:
            Seed used to set the simulator for reproducible experiments

        actor_seed:
            Seed used to set the actor for reproducible experiments

        Returns
        -------

        """
        try:
            import getfem
        except ImportError as exc_:
            raise RuntimeError("Impossible to `generate` the rolling wheel dataset if you don't have "
                               "the getfem package installed") from exc_
        if nb_samples <= 0:
            raise RuntimeError("Impossible to generate a null or negative number of samples.")

        samples = actor.generate_samples(nb_samples=nb_samples, sampler_seed=actor_seed)
        self._init_store_data(simulator=simulator, nb_samples=nb_samples)

        for current_size, sample in enumerate(tqdm(samples, desc=self.name)):
            # rebuild a fresh simulator from the reference one for each sample
            simulator = type(simulator)(simulatorInstance=simulator)
            simulator.modify_state(actor=sample)
            simulator.build_model()
            solverState = simulator.run_problem()

            self._store_obs(current_size=current_size, obs=simulator)

        self.size = nb_samples

        if path_out is not None:
            # save the data to disk
            self._save_internal_data(path_out)

    def _init_store_data(self, simulator, nb_samples):
        """Preallocate one (nb_samples, field_size) array per stored attribute."""
        self.data = dict()
        for attr_nm in self._attr_names:
            array_ = simulator.get_variable_value(field_name=attr_nm)
            self.data[attr_nm] = np.zeros((nb_samples, array_.shape[0]), dtype=array_.dtype)

    def _store_obs(self, current_size, obs):
        """Store the solution fields of one solved simulation in row `current_size`."""
        for attr_nm in self._attr_names:
            array_ = obs.get_solution(field_name=attr_nm)
            self.data[attr_nm][current_size, :] = array_

    def _save_internal_data(self, path_out):
        """save the self.data in a proper format"""
        full_path_out = os.path.join(os.path.abspath(path_out), self.name)

        if not os.path.exists(os.path.abspath(path_out)):
            os.mkdir(os.path.abspath(path_out))
            self.logger.info(f"Creating the path {path_out} to store the datasets [data will be stored under {full_path_out}]")

        if os.path.exists(full_path_out):
            # deleting previously saved data
            self.logger.warning(f"Deleting previous run at {full_path_out}")
            shutil.rmtree(full_path_out)

        os.mkdir(full_path_out)
        self.logger.info(f"Creating the path {full_path_out} to store the dataset name {self.name}")

        for attr_nm in self._attr_names:
            np.savez_compressed(f"{os.path.join(full_path_out, attr_nm)}.npz", data=self.data[attr_nm])

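    # Note (descriptive comment): `load` expects the layout written by
    # `_save_internal_data`, i.e. one `<path>/<self.name>/<attribute>.npz` file per
    # attribute, each holding a single array under the key "data".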
    def load(self, path):
        """Load a dataset previously saved by `generate` (with a given `path_out`) from `path`."""
        if not os.path.exists(path):
            raise RuntimeError(f"{path} cannot be found on your computer")
        if not os.path.isdir(path):
            raise RuntimeError(f"{path} is not a valid directory")
        full_path = os.path.join(path, self.name)
        if not os.path.exists(full_path):
            raise RuntimeError(f"There is no data saved in {full_path}. Have you called `dataset.generate()` with "
                               f"a given `path_out`?")
        for attr_nm in self._attr_names:
            path_this_array = f"{os.path.join(full_path, attr_nm)}.npz"
            if not os.path.exists(path_this_array):
                raise RuntimeError(f"Impossible to load data {attr_nm}. Have you called `dataset.generate()` with "
                                   f"a given `path_out` and such that `dataset` is built with the right `attr_names`?")

        if self.data is not None:
            warnings.warn(f"Deleting previous run in attempting to load the new one located at {path}")
        self.data = {}
        self.size = None
        for attr_nm in self._attr_names:
            path_this_array = f"{os.path.join(full_path, attr_nm)}.npz"
            self.data[attr_nm] = np.load(path_this_array)["data"]
            self.size = self.data[attr_nm].shape[0]

    def get_data(self, index):
        """
        This function returns the data entries that match the index `index`

        Parameters
        ----------
        index:
            A list of integers (a single integer is also accepted)

        Returns
        -------
        dict
            A dictionary mapping each attribute name to the corresponding rows of data

        """
        super().get_data(index)  # check that everything is legit

        # make sure the index is a numpy array
        if isinstance(index, list):
            index = np.array(index, dtype=int)
        elif isinstance(index, int):
            index = np.array([index], dtype=int)

        # init the results
        res = {}
        nb_sample = index.size
        for el in self._attr_names:
            res[el] = np.zeros((nb_sample, self.data[el].shape[1]), dtype=self.data[el].dtype)

        for el in self._attr_names:
            res[el][:] = self.data[el][index, :]

        return res


if __name__ == '__main__':
    import math
    from lips.physical_simulator.getfemSimulator import GetfemSimulator

    physicalDomain = {
        "Mesher": "Getfem",
        "refNumByRegion": {"HOLE_BOUND": 1, "CONTACT_BOUND": 2, "EXTERIOR_BOUND": 3},
        "wheelDimensions": (8., 15.),
        "meshSize": 1
    }

    physicalProperties = {
        "ProblemType": "StaticMechanicalStandard",
        "materials": [["ALL", {"law": "LinearElasticity", "young": 21E6, "poisson": 0.3}]],
        "sources": [["ALL", {"type": "Uniform", "source_x": 0.0, "source_y": 0}]],
        "dirichlet": [["HOLE_BOUND", {"type": "scalar", "Disp_Amplitude": 6, "Disp_Angle": -math.pi / 2}]],
        "contact": [["CONTACT_BOUND", {"type": "Plane", "gap": 2.0, "fricCoeff": 0.9}]]
    }
    training_simulator = GetfemSimulator(physicalDomain=physicalDomain, physicalProperties=physicalProperties)

    from lips.dataset.sampler import LHSSampler
    trainingInput = {
        "young": (75.0, 85.0),
        "poisson": (0.38, 0.44),
        "fricCoeff": (0.5, 0.8)
    }

    training_actor = LHSSampler(space_params=trainingInput)
    nb_sample_train = 10
    path_datasets = "TotoDir"

    import lips.physical_simulator.GetfemSimulator.PhysicalFieldNames as PFN
    attr_names = (PFN.displacement, PFN.contactMultiplier)

    rollingWheelDataSet = RollingWheelDataSet("train", attr_names=attr_names)
    rollingWheelDataSet.generate(simulator=training_simulator,
                                 actor=training_actor,
                                 path_out=path_datasets,
                                 nb_samples=nb_sample_train,
                                 actor_seed=42
                                 )
    print(rollingWheelDataSet.get_data(index=0))
    print(rollingWheelDataSet.data)
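
    # Optional sanity check (sketch): reload the dataset just written under
    # `path_datasets` and fetch a small batch through the public API defined above.
    # This only uses `load` and `get_data`; nothing here is required for generation.
    reloaded = RollingWheelDataSet("train", attr_names=attr_names)
    reloaded.load(path=path_datasets)
    batch = reloaded.get_data(index=[0, 1, 2])
    for field_name, values in batch.items():
        print(field_name, values.shape)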