Commit 9e1c6a0

Dd dev (#63) Data & Physics simulator for Michelin use case

* Add all scripts related to the pure GetFEM simulator
* Build a bridge between LIPS and GetFEM
* Remove tests from the GetFEM repository; add a simple test to check the bridge
* Add a sampler for parameter-space discretization
* Enable interface features to be used from the simulator interface: retrieve fields
* Generate a database using the GetFEM simulator
* Define an optional seed for sampling

Co-authored-by: David Danan <david.danan@irt-systemx.fr>
1 parent: 6b409d0 · commit: 9e1c6a0

13 files changed: +2854 −4 lines

.github/workflows/gitlab_sync.yml

+3 −3

@@ -5,23 +5,23 @@ on:
     branches:
       - main
       - 'ml-*'
-      - 'bd-dev'
+      - 'bd-*'
       - 'dd-*'
       - 'jp-*'
       - 'am-*'
   delete:
     branches:
       - main
       - 'ml-*'
-      - 'bd-dev'
+      - 'bd-*'
       - 'dd-*'
       - 'jp-*'
       - 'am-*'
   pull_request:
     branches:
       - main
       - 'ml-*'
-      - 'bd-dev'
+      - 'bd-*'
       - 'dd-*'
       - 'jp-*'
       - 'am-*'

.github/workflows/pylint.yml

+1 −1

@@ -5,7 +5,7 @@ on:
     branches:
       - main
       - 'ml-*'
-      - 'bd-dev'
+      - 'bd-*'
       - 'dd-*'
       - 'jp-*'
       - 'am-*'

lips/dataset/rollingWheelDataSet.py

+230

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import warnings
import shutil
import copy
from typing import Union

import numpy as np
from tqdm import tqdm  # TODO remove for final push

from lips.dataset.dataSet import DataSet
from lips.logger import CustomLogger


class RollingWheelDataSet(DataSet):
    """
    This specific DataSet uses the GetFEM framework to simulate data arising from a rolling wheel problem.
    """
    def __init__(self,
                 name="train",
                 attr_names=("disp",),
                 log_path: Union[str, None] = None
                 ):
        DataSet.__init__(self, name=name)
        self._attr_names = copy.deepcopy(attr_names)
        self.size = 0

        # logger
        self.logger = CustomLogger(__class__.__name__, log_path).logger

    def generate(self,
                 simulator: "GetfemSimulator",
                 actor,
                 path_out,
                 nb_samples,
                 simulator_seed: Union[None, int] = None,
                 actor_seed: Union[None, int] = None):
        """
        For this dataset, we use a GetfemSimulator and a Sampler to generate data from a rolling wheel.

        Parameters
        ----------
        simulator:
            In this case, this should be a GetfemSimulator instance

        actor:
            In this case, it is the sampler used to discretize the input parameter space

        path_out:
            The path where the data will be saved

        nb_samples:
            Number of rows (examples) in the final dataset

        simulator_seed:
            Seed used to set the simulator for reproducible experiments

        actor_seed:
            Seed used to set the actor for reproducible experiments

        Returns
        -------

        """
        try:
            import getfem
        except ImportError as exc_:
            raise RuntimeError("Impossible to `generate` the rolling wheel dataset if you don't have "
                               "the getfem package installed") from exc_
        if nb_samples <= 0:
            raise RuntimeError("Impossible to generate a non-positive number of samples.")

        samples = actor.generate_samples(nb_samples=nb_samples, sampler_seed=actor_seed)
        self._init_store_data(simulator=simulator, nb_samples=nb_samples)

        for current_size, sample in enumerate(tqdm(samples, desc=self.name)):
            # re-instantiate the simulator from the template instance, then apply the sampled parameters
            simulator = type(simulator)(simulatorInstance=simulator)
            simulator.modify_state(actor=sample)
            simulator.build_model()
            solverState = simulator.run_problem()

            self._store_obs(current_size=current_size, obs=simulator)

        self.size = nb_samples

        if path_out is not None:
            # I should save the data
            self._save_internal_data(path_out)

    def _init_store_data(self, simulator, nb_samples):
        self.data = dict()
        for attr_nm in self._attr_names:
            array_ = simulator.get_variable_value(field_name=attr_nm)
            self.data[attr_nm] = np.zeros((nb_samples, array_.shape[0]), dtype=array_.dtype)

    def _store_obs(self, current_size, obs):
        for attr_nm in self._attr_names:
            array_ = obs.get_solution(field_name=attr_nm)
            self.data[attr_nm][current_size, :] = array_

    def _save_internal_data(self, path_out):
        """save the self.data in a proper format"""
        full_path_out = os.path.join(os.path.abspath(path_out), self.name)

        if not os.path.exists(os.path.abspath(path_out)):
            os.mkdir(os.path.abspath(path_out))
            self.logger.info(f"Creating the path {path_out} to store the datasets "
                             f"[data will be stored under {full_path_out}]")

        if os.path.exists(full_path_out):
            # deleting previously saved data
            self.logger.warning(f"Deleting previous run at {full_path_out}")
            shutil.rmtree(full_path_out)

        os.mkdir(full_path_out)
        self.logger.info(f"Creating the path {full_path_out} to store the dataset name {self.name}")

        for attr_nm in self._attr_names:
            np.savez_compressed(f"{os.path.join(full_path_out, attr_nm)}.npz", data=self.data[attr_nm])

    def load(self, path):
        if not os.path.exists(path):
            raise RuntimeError(f"{path} cannot be found on your computer")
        if not os.path.isdir(path):
            raise RuntimeError(f"{path} is not a valid directory")
        full_path = os.path.join(path, self.name)
        if not os.path.exists(full_path):
            raise RuntimeError(f"There is no data saved in {full_path}. Have you called `dataset.generate()` with "
                               f"a given `path_out`?")
        for attr_nm in self._attr_names:
            path_this_array = f"{os.path.join(full_path, attr_nm)}.npz"
            if not os.path.exists(path_this_array):
                raise RuntimeError(f"Impossible to load data {attr_nm}. Have you called `dataset.generate()` with "
                                   f"a given `path_out` and such that `dataset` is built with the right `attr_names`?")

        if self.data is not None:
            warnings.warn(f"Deleting previous run in attempting to load the new one located at {path}")
        self.data = {}
        self.size = None
        for attr_nm in self._attr_names:
            path_this_array = f"{os.path.join(full_path, attr_nm)}.npz"
            self.data[attr_nm] = np.load(path_this_array)["data"]
            self.size = self.data[attr_nm].shape[0]

    def get_data(self, index):
        """
        This function returns the data that match the index `index`

        Parameters
        ----------
        index:
            A list of integers

        Returns
        -------

        """
        super().get_data(index)  # check that everything is legit

        # make sure the index is a numpy array
        if isinstance(index, list):
            index = np.array(index, dtype=int)
        elif isinstance(index, int):
            index = np.array([index], dtype=int)

        # init the results
        res = {}
        nb_sample = index.size
        for el in self._attr_names:
            res[el] = np.zeros((nb_sample, self.data[el].shape[1]), dtype=self.data[el].dtype)

        for el in self._attr_names:
            res[el][:] = self.data[el][index, :]

        return res


if __name__ == '__main__':
    import math
    from lips.physical_simulator.getfemSimulator import GetfemSimulator

    physicalDomain = {
        "Mesher": "Getfem",
        "refNumByRegion": {"HOLE_BOUND": 1, "CONTACT_BOUND": 2, "EXTERIOR_BOUND": 3},
        "wheelDimensions": (8., 15.),
        "meshSize": 1
    }

    physicalProperties = {
        "ProblemType": "StaticMechanicalStandard",
        "materials": [["ALL", {"law": "LinearElasticity", "young": 21E6, "poisson": 0.3}]],
        "sources": [["ALL", {"type": "Uniform", "source_x": 0.0, "source_y": 0}]],
        "dirichlet": [["HOLE_BOUND", {"type": "scalar", "Disp_Amplitude": 6, "Disp_Angle": -math.pi/2}]],
        "contact": [["CONTACT_BOUND", {"type": "Plane", "gap": 2.0, "fricCoeff": 0.9}]]
    }
    training_simulator = GetfemSimulator(physicalDomain=physicalDomain, physicalProperties=physicalProperties)

    from lips.dataset.sampler import LHSSampler
    trainingInput = {
        "young": (75.0, 85.0),
        "poisson": (0.38, 0.44),
        "fricCoeff": (0.5, 0.8)
    }

    training_actor = LHSSampler(space_params=trainingInput)
    nb_sample_train = 10
    path_datasets = "TotoDir"

    import lips.physical_simulator.GetfemSimulator.PhysicalFieldNames as PFN
    attr_names = (PFN.displacement, PFN.contactMultiplier)

    rollingWheelDataSet = RollingWheelDataSet("train", attr_names=attr_names)
    rollingWheelDataSet.generate(simulator=training_simulator,
                                 actor=training_actor,
                                 path_out=path_datasets,
                                 nb_samples=nb_sample_train,
                                 actor_seed=42
                                 )
    print(rollingWheelDataSet.get_data(index=0))
    print(rollingWheelDataSet.data)
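
A minimal sketch of reading the generated data back, assuming the `TotoDir` output produced by the `__main__` block above (the field names come from the same `PhysicalFieldNames` module used there); it relies only on the `load` and `get_data` methods defined in this file:

import lips.physical_simulator.GetfemSimulator.PhysicalFieldNames as PFN
from lips.dataset.rollingWheelDataSet import RollingWheelDataSet

# Rebuild a dataset object with the same name and fields, then reload the saved .npz files
attr_names = (PFN.displacement, PFN.contactMultiplier)
reloaded = RollingWheelDataSet("train", attr_names=attr_names)
reloaded.load(path="TotoDir")              # reads TotoDir/train/<field>.npz into reloaded.data
batch = reloaded.get_data(index=[0, 1])    # dict: field name -> array of shape (2, n_dofs)
for field, values in batch.items():
    print(field, values.shape)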

lips/dataset/sampler.py

+80

@@ -0,0 +1,80 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# This file introduces the sampling methods used to generate a space of parameters
import abc
import csv

import numpy as np
import pyDOE2 as doe


class Sampler(metaclass=abc.ABCMeta):
    def __init__(self, space_params):
        self.space_params = space_params
        self.sampling_output = []
        self.sampling_name = ""

    def generate_samples(self, nb_samples, sampler_seed=None):
        self.sampling_output = self._define_sampling_method(nb_samples=nb_samples, sampler_seed=sampler_seed)
        return self.sampling_output

    @abc.abstractmethod
    def _define_sampling_method(self, nb_samples, sampler_seed=None):
        pass

    def save_samples_in_file(self, filename, samples=None):
        if samples is None:
            samples = self.sampling_output

        # every sample must expose the same input parameters
        fieldNum = [len(sample.keys()) for sample in samples]
        if fieldNum.count(fieldNum[0]) != len(fieldNum):
            raise RuntimeError("Samples do not have the same input parameters")

        with open(filename, mode='w', newline='') as csv_file:
            fieldnames = list(samples[0].keys())
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            for paramsSet in samples:
                writer.writerow(paramsSet)

    def __str__(self):
        sInfo = "Type of sampling: " + self.sampling_name + "\n"
        sInfo += "Parameters\n"
        for paramName, paramVal in self.space_params.items():
            sInfo += "\t" + str(paramName) + ": " + str(paramVal) + "\n"
        return sInfo

    def __len__(self):
        return len(self.sampling_output)


class LHSSampler(Sampler):
    def __init__(self, space_params):
        super(LHSSampler, self).__init__(space_params=space_params)
        self.sampling_name = "LHSSampler"

    def _define_sampling_method(self, nb_samples, sampler_seed=None):
        space_params = self.space_params
        nfactor = len(space_params)
        # Latin Hypercube Sampling on the unit hypercube, one column per parameter
        self.vals = doe.lhs(nfactor, samples=nb_samples, random_state=sampler_seed, criterion="maximin")

        vals = np.transpose(self.vals)
        paramsVectByName = {}
        for i, paramName in enumerate(space_params.keys()):
            minVal, maxVal = space_params[paramName]
            paramsVectByName[paramName] = minVal + vals[i] * (maxVal - minVal)
        # turn {param: vector of values} into a list of {param: value} dicts, one per sample
        return list(map(dict, zip(*[[(k, v) for v in value] for k, value in paramsVectByName.items()])))


if __name__ == "__main__":
    # the same seed must yield the same samples
    params = {"params1": (21E6, 22E6), "params2": (0.2, 0.3)}
    sample_params = LHSSampler(space_params=params)
    test_params1 = sample_params.generate_samples(nb_samples=2, sampler_seed=42)
    test_params2 = sample_params.generate_samples(nb_samples=2, sampler_seed=42)
    assert test_params1 == test_params2

    # the number of generated samples must match the request
    params = {"params1": (21E6, 22E6), "params2": (0.2, 0.3)}
    sample_params = LHSSampler(space_params=params)
    nb_samples = 5
    test_params = sample_params.generate_samples(nb_samples=nb_samples)
    assert len(test_params) == nb_samples
    print(sample_params)
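
A small usage sketch for `save_samples_in_file`, assuming the same parameter ranges as the training input used in `rollingWheelDataSet.py` above (the file name `train_samples.csv` is only illustrative):

from lips.dataset.sampler import LHSSampler

# draw a reproducible LHS design and persist it as CSV (one row per sample, one column per parameter)
space = {"young": (75.0, 85.0), "poisson": (0.38, 0.44), "fricCoeff": (0.5, 0.8)}
sampler = LHSSampler(space_params=space)
samples = sampler.generate_samples(nb_samples=10, sampler_seed=42)   # list of dicts
sampler.save_samples_in_file("train_samples.csv")
print(len(sampler), "samples written")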
