-
Notifications
You must be signed in to change notification settings - Fork 68
/
_config_import.py
467 lines (421 loc) · 16.3 KB
/
_config_import.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
import ast
import copy
import difflib
import importlib
import os
import pathlib
from dataclasses import field
from functools import partial
from types import SimpleNamespace
import matplotlib
import mne
import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError
from ._logging import gen_log_kwargs, logger
from .typing import PathLike
def _import_config(
    *,
    config_path: PathLike | None,
    overrides: SimpleNamespace | None = None,
    check: bool = True,
    log: bool = True,
) -> SimpleNamespace:
    """Import the default config and the user's config.

    Parameters
    ----------
    config_path
        Path to the user's configuration file, or ``None`` to rely on
        ``overrides`` and environment variables only.
    overrides
        Namespace of values (e.g., from command-line switches) applied on
        top of the user's configuration.
    check
        If True, validate the merged configuration and check user names
        for misspellings / removed options.
    log
        If True, emit log messages about overrides and adjustments.

    Returns
    -------
    SimpleNamespace
        The merged configuration. Execution-related options are moved out
        of the namespace into ``config.exec_params``.
    """
    # Get the default
    config = _get_default_config()
    # Public names users generally will have in their config
    valid_names = [d for d in dir(config) if not d.startswith("_")]
    # Names that we will reduce the SimpleConfig to before returning
    # (see _update_with_user_config)
    keep_names = [d for d in dir(config) if not d.startswith("__")] + [
        "config_path",
        "PIPELINE_NAME",
        "VERSION",
        "CODE_URL",
        "_raw_split_size",
        "_epochs_split_size",
    ]
    # Update with user config
    user_names = _update_with_user_config(
        config=config,
        config_path=config_path,
        overrides=overrides,
        log=log,
    )
    extra_exec_params_keys = ()
    # Testing hook: an additional config file can be layered on via env var
    extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "")
    if extra_config:
        msg = f"With testing config: {extra_config}"
        logger.info(**gen_log_kwargs(message=msg, emoji="override"))
        _update_config_from_path(
            config=config,
            config_path=extra_config,
        )
        extra_exec_params_keys = ("_n_jobs",)
    keep_names.extend(extra_exec_params_keys)
    # Check it
    if check:
        _check_config(config, config_path)
        _check_misspellings_removals(
            valid_names=valid_names,
            user_names=user_names,
            log=log,
            config_validation=config.config_validation,
        )
    # Finally, reduce to our actual supported params (all keep_names should be present)
    config = SimpleNamespace(**{k: getattr(config, k) for k in keep_names})
    # Take some standard actions
    mne.set_log_level(verbose=config.mne_log_level.upper())
    # Take variables out of config (which affects the pipeline outputs) and
    # put into config.exec_params (which affect the pipeline execution methods,
    # but not the outputs)
    keys = (
        # Parallelization
        "n_jobs",
        "parallel_backend",
        "dask_temp_dir",
        "dask_worker_memory_limit",
        "dask_open_dashboard",
        # Interaction
        "on_error",
        "interactive",
        # Caching
        "memory_location",
        "memory_subdir",
        "memory_verbose",
        "memory_file_method",
        # Misc
        "deriv_root",
        "config_path",
    ) + extra_exec_params_keys
    # deriv_root affects both outputs and execution, so it stays in both places
    in_both = {"deriv_root"}
    exec_params = SimpleNamespace(**{k: getattr(config, k) for k in keys})
    for k in keys:
        if k not in in_both:
            delattr(config, k)
    config.exec_params = exec_params
    return config
def _get_default_config():
    """Return a fresh SimpleNamespace holding the default config values.

    Values are deep-copied from the ``_config`` module so that mutating the
    returned namespace cannot leak back into the module. Names bound by
    import statements are excluded by parsing the module's source.
    """
    from . import _config

    # Collect the names bound by import statements in _config's source so
    # they can be filtered out of the returned namespace.
    parsed = ast.parse(pathlib.Path(_config.__file__).read_text())
    imported_names = set()
    for node in parsed.body:
        if isinstance(node, ast.Import | ast.ImportFrom):
            for alias in node.names:
                imported_names.add(alias.asname or alias.name)

    # Deep-copy each remaining value (keys are immutable strings, so only
    # the values need copying); skip dunders and imported names.
    defaults = {}
    for name, value in _config.__dict__.items():
        if name.startswith("__") or name in imported_names:
            continue
        defaults[name] = copy.deepcopy(value)
    return SimpleNamespace(**defaults)
def _update_config_from_path(
    *,
    config: SimpleNamespace,
    config_path: PathLike,
):
    """Execute the config file at ``config_path`` and copy its variables
    onto ``config`` in place.

    Returns the list of public (non-underscore) names that were set, which
    the caller uses for misspelling / removed-option checks.
    """
    resolved = pathlib.Path(config_path).expanduser().resolve(strict=True)
    # Import configuration from an arbitrary path without having to fiddle
    # with `sys.path`.
    spec = importlib.util.spec_from_file_location(
        name="custom_config", location=resolved
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    public_names = []
    for attr in dir(module):
        if attr.startswith("__"):
            continue
        # don't validate private vars, but do add to config
        # (e.g., so that our hidden _raw_split_size is included)
        if not attr.startswith("_"):
            public_names.append(attr)
        value = getattr(module, attr)
        logger.debug(f"Overwriting: {attr} -> {value}")
        setattr(config, attr, value)
    return public_names
def _update_with_user_config(
    *,
    config: SimpleNamespace,  # modified in-place
    config_path: PathLike | None,
    overrides: SimpleNamespace | None,
    log: bool = False,
) -> list[str]:
    """Layer user config file, overrides, and env vars onto ``config``.

    Parameters
    ----------
    config
        The default configuration namespace; modified in-place.
    config_path
        Path to the user's configuration file, or ``None``.
    overrides
        Values (e.g., from CLI switches) applied after the user config.
    log
        If True, log each override and consistency adjustment.

    Returns
    -------
    list[str]
        Public names set by the user's configuration file.
    """
    # 1. Basics and hidden vars
    from . import __version__

    config.PIPELINE_NAME = "mne-bids-pipeline"
    config.VERSION = __version__
    config.CODE_URL = "https://github.com/mne-tools/mne-bids-pipeline"
    config._raw_split_size = "2GB"
    config._epochs_split_size = "2GB"
    # 2. User config
    user_names = list()
    if config_path is not None:
        user_names.extend(
            _update_config_from_path(
                config=config,
                config_path=config_path,
            )
        )
    config.config_path = config_path
    # 3. Overrides via command-line switches
    overrides = overrides or SimpleNamespace()
    for name in dir(overrides):
        if not name.startswith("__"):
            val = getattr(overrides, name)
            if log:
                msg = f"Overriding config.{name} = {repr(val)}"
                logger.info(**gen_log_kwargs(message=msg, emoji="override"))
            setattr(config, name, val)
    # 4. Env vars and other triaging
    if not config.bids_root:
        # Fall back to the BIDS_ROOT environment variable
        root = os.getenv("BIDS_ROOT", None)
        if root is None:
            raise ValueError(
                "You need to specify `bids_root` in your configuration, or "
                "define an environment variable `BIDS_ROOT` pointing to the "
                "root folder of your BIDS dataset"
            )
        config.bids_root = root
    config.bids_root = pathlib.Path(config.bids_root).expanduser().resolve()
    if config.deriv_root is None:
        config.deriv_root = config.bids_root / "derivatives" / config.PIPELINE_NAME
    config.deriv_root = pathlib.Path(config.deriv_root).expanduser().resolve()
    # 5. Consistency
    log_kwargs = dict(emoji="override")
    if config.interactive:
        # Interactive mode implies dropping into the debugger on error
        if log and config.on_error != "debug":
            msg = 'Setting config.on_error="debug" because of interactive mode'
            logger.info(**gen_log_kwargs(message=msg, **log_kwargs))
        config.on_error = "debug"
    else:
        matplotlib.use("Agg")  # do not open any window  # noqa
    if config.on_error == "debug":
        # Debug mode requires serial execution with the loky backend
        if log and config.n_jobs != 1:
            msg = 'Setting config.n_jobs=1 because config.on_error="debug"'
            logger.info(**gen_log_kwargs(message=msg, **log_kwargs))
        config.n_jobs = 1
        if log and config.parallel_backend != "loky":
            msg = (
                'Setting config.parallel_backend="loky" because '
                'config.on_error="debug"'
            )
            logger.info(**gen_log_kwargs(message=msg, **log_kwargs))
        config.parallel_backend = "loky"
    return user_names
def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None:
    """Run consistency checks on the merged config, raising ValueError on error.

    Parameters
    ----------
    config
        The merged configuration namespace to validate.
    config_path
        The user's config path (used for error-message context only).
    """
    _pydantic_validate(config=config, config_path=config_path)
    # Eventually all of these could be pydantic-validated, but for now we'll
    # just change the ones that are easy
    config.bids_root.resolve(strict=True)  # raises if the path does not exist
    if (
        config.use_maxwell_filter
        and len(set(config.ch_types).intersection(("meg", "grad", "mag"))) == 0
    ):
        raise ValueError("Cannot use Maxwell filter without MEG channels.")
    reject = config.reject
    ica_reject = config.ica_reject
    if config.spatial_filter == "ica":
        # Require a high-pass of at least 1 Hz before ICA
        if config.ica_l_freq < 1:
            raise ValueError(
                "You requested to high-pass filter the data before ICA with "
                f"ica_l_freq={config.ica_l_freq} Hz. Please increase this "
                "setting to 1 Hz or above to ensure reliable ICA function."
            )
        # The ICA high-pass cutoff must not be below the raw-data cutoff
        if (
            config.ica_l_freq is not None
            and config.l_freq is not None
            and config.ica_l_freq < config.l_freq
        ):
            raise ValueError(
                "You requested a lower high-pass filter cutoff frequency for "
                f"ICA than for your raw data: ica_l_freq = {config.ica_l_freq}"
                f" < l_freq = {config.l_freq}. Adjust the cutoffs such that "
                "ica_l_freq >= l_freq, or set ica_l_freq to None if you do "
                "not wish to apply an additional high-pass filter before "
                "running ICA."
            )
        # Epoch rejection thresholds must be at least as strict as the
        # thresholds applied before ICA fitting (skip autoreject modes)
        if (
            ica_reject is not None
            and reject is not None
            and reject not in ["autoreject_global", "autoreject_local"]
        ):
            for ch_type in reject:
                if ch_type in ica_reject and reject[ch_type] > ica_reject[ch_type]:
                    raise ValueError(
                        f'Rejection threshold in reject["{ch_type}"] '
                        f"({reject[ch_type]}) must be at least as stringent "
                        "as that in "
                        f'ica_reject["{ch_type}"] ({ica_reject[ch_type]})'
                    )
    if config.noise_cov == "emptyroom" and "eeg" in config.ch_types:
        raise ValueError(
            "You requested to process data that contains EEG channels. In "
            "this case, noise covariance can only be estimated from the "
            "experimental data, e.g., the pre-stimulus period. Please set "
            "noise_cov to (tmin, tmax)"
        )
    if config.noise_cov == "emptyroom" and not config.process_empty_room:
        raise ValueError(
            "You requested noise covariance estimation from empty-room "
            'recordings by setting noise_cov = "emptyroom", but you did not '
            "enable empty-room data processing. "
            "Please set process_empty_room = True"
        )
    bl = config.baseline
    if bl is not None:
        # The baseline must lie within the epochs interval ...
        if (bl[0] is not None and bl[0] < config.epochs_tmin) or (
            bl[1] is not None and bl[1] > config.epochs_tmax
        ):
            raise ValueError(
                f"baseline {bl} outside of epochs interval "
                f"{[config.epochs_tmin, config.epochs_tmax]}."
            )
        # ... and be a non-empty interval
        if bl[0] is not None and bl[1] is not None and bl[0] >= bl[1]:
            raise ValueError(
                f"The end of the baseline period must occur after its start, "
                f"but you set baseline={bl}"
            )
    # check cluster permutation parameters
    if config.cluster_n_permutations < 10 / config.cluster_permutation_p_threshold:
        raise ValueError(
            "cluster_n_permutations is not big enough to calculate "
            "the p-values accurately."
        )
    # Another check that depends on some of the functions defined above
    if not config.task_is_rest and config.conditions is None:
        raise ValueError(
            "Please indicate the name of your conditions in your "
            "configuration. Currently the `conditions` parameter is empty. "
            "This is only allowed for resting-state analysis."
        )
    if not isinstance(config.mf_destination, str):
        # A non-string destination must be a 4x4 transformation matrix
        destination = np.array(config.mf_destination, float)
        if destination.shape != (4, 4):
            raise ValueError(
                "config.mf_destination, if array-like, must have shape (4, 4) "
                f"but got shape {destination.shape}"
            )
def _default_factory(key, val):
# convert a default to a default factory if needed, having an explicit
# allowlist of non-empty ones
allowlist = [
{"n_mag": 1, "n_grad": 1, "n_eeg": 1}, # n_proj_*
{"custom": (8, 24.0, 40)}, # decoding_csp_freqs
{"suffix": "ave"}, # source_info_path_update
["evoked"], # inverse_targets
[4, 8, 16], # autoreject_n_interpolate
]
for typ in (dict, list):
if isinstance(val, typ):
try:
idx = allowlist.index(val)
except ValueError:
assert val == typ(), (key, val)
default_factory = typ
else:
if typ is dict:
default_factory = partial(typ, **allowlist[idx])
else:
assert typ is list
default_factory = partial(typ, allowlist[idx])
return field(default_factory=default_factory)
return val
def _pydantic_validate(
    config: SimpleNamespace,
    config_path: PathLike | None,
):
    """Create dataclass from config type hints and validate with pydantic."""
    # https://docs.pydantic.dev/latest/usage/dataclasses/
    from . import _config as root_config

    # Mirror _config's annotations, converting mutable container defaults
    # into default factories so pydantic can consume them safely.
    annotations = dict(root_config.__annotations__)
    attrs = {
        key: _default_factory(key, root_config.__dict__[key])
        for key in annotations
    }
    model_name = "user configuration"
    if config_path is not None:
        model_name = f"{model_name} from {config_path}"
    settings = ConfigDict(
        arbitrary_types_allowed=True,  # needed in 2.6.0 to allow DigMontage for example
        validate_assignment=True,
        strict=True,  # do not allow float for int for example
        extra="forbid",
    )
    # Dynamically build a pydantic model from the annotations/defaults
    UserConfig = type(
        model_name,
        (BaseModel,),
        {"__annotations__": annotations, "model_config": settings, **attrs},
    )
    # Now use pydantic to automagically validate
    candidate = {
        key: val for key, val in config.__dict__.items() if key in annotations
    }
    try:
        UserConfig.model_validate(candidate)
    except ValidationError as err:
        # Re-raise as a plain ValueError without pydantic's traceback chain
        raise ValueError(str(err)) from None
# Config options that have been removed or renamed, keyed by the old name.
# "new_name" is the current option; the optional "instead" string overrides
# the default "use <new_name> instead" hint in the message emitted by
# _check_misspellings_removals.
_REMOVED_NAMES = {
    "debug": dict(
        new_name="on_error",
        instead='use on_error="debug" instead',
    ),
    "decim": dict(
        new_name="epochs_decim",
        instead=None,
    ),
    "resample_sfreq": dict(
        new_name="raw_resample_sfreq",
    ),
    "N_JOBS": dict(
        new_name="n_jobs",
    ),
}
def _check_misspellings_removals(
    *,
    valid_names: list[str],
    user_names: list[str],
    log: bool,
    config_validation: str,
) -> None:
    """Flag unknown config names set by the user.

    Two cases are reported via ``_handle_config_error``: names that look
    like typos of valid options, and names that used to be valid but have
    since been removed or renamed.
    """
    # for each name in the user names, check if it's in the valid names but
    # the correct one is not defined
    known = set(valid_names)
    for name in user_names:
        if name in known:
            continue
        prefix = f"Found a variable named {repr(name)} in your custom config,"
        # Suggest the closest valid option, unless the user already set it
        suggestions = difflib.get_close_matches(name, known, n=1)
        if suggestions and suggestions[0] not in user_names:
            _handle_config_error(
                f"{prefix} did you mean {repr(suggestions[0])}? "
                "If so, please correct the error. If not, please rename "
                "the variable to reduce ambiguity and avoid this message, "
                "or set config.config_validation to 'warn' or 'ignore'.",
                log,
                config_validation,
            )
        # Explicitly flag options that were removed or renamed
        if name in _REMOVED_NAMES:
            replacement = _REMOVED_NAMES[name]["new_name"]
            if replacement not in user_names:
                hint = _REMOVED_NAMES[name].get("instead", None)
                if hint is None:
                    hint = f"use {replacement} instead"
                _handle_config_error(
                    f"{prefix} this variable has been removed as a valid "
                    f"config option, {hint}.",
                    log,
                    config_validation,
                )
def _handle_config_error(
msg: str,
log: bool,
config_validation: str,
) -> None:
if config_validation == "raise":
raise ValueError(msg)
elif config_validation == "warn":
if log:
logger.warning(**gen_log_kwargs(message=msg, emoji="🛟"))