Skip to content

Commit

Permalink
Improve preprocessing utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Jan 27, 2024
1 parent 7be2906 commit 80c1fd0
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 74 deletions.
80 changes: 49 additions & 31 deletions weaver/utils/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __init__(self, print_info=True, **kwargs):
if print_info:
_logger.debug(opts)

self.train_load_branches = set()
self.train_aux_branches = set()
self.test_load_branches = set()
self.test_aux_branches = set()

self.selection = opts['selection']
self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
self.var_funcs = copy.deepcopy(opts['new_variables'])
Expand Down Expand Up @@ -101,26 +106,27 @@ def _get(idx, default):
assert (isinstance(self.label_value, list))
self.label_names = ('_label_',)
label_exprs = ['ak.to_numpy(%s)' % k for k in self.label_value]
self.var_funcs['_label_'] = 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs))
self.var_funcs['_labelcheck_'] = 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs))
self.register('_label_', 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)))
self.register('_labelcheck_', 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)), 'train')
else:
self.label_names = tuple(self.label_value.keys())
self.var_funcs.update(self.label_value)
self.register(self.label_value)
self.basewgt_name = '_basewgt_'
self.weight_name = None
if opts['weights'] is not None:
self.weight_name = 'weight_'
self.weight_name = '_weight_'
self.use_precomputed_weights = opts['weights']['use_precomputed_weights']
if self.use_precomputed_weights:
self.var_funcs[self.weight_name] = '*'.join(opts['weights']['weight_branches'])
self.register(self.weight_name, '*'.join(opts['weights']['weight_branches']), 'train')
else:
self.reweight_method = opts['weights']['reweight_method']
self.reweight_basewgt = opts['weights'].get('reweight_basewgt', None)
if self.reweight_basewgt:
self.var_funcs[self.basewgt_name] = self.reweight_basewgt
self.register(self.basewgt_name, self.reweight_basewgt, 'train')
self.reweight_branches = tuple(opts['weights']['reweight_vars'].keys())
self.reweight_bins = tuple(opts['weights']['reweight_vars'].values())
self.reweight_classes = tuple(opts['weights']['reweight_classes'])
self.register(self.reweight_branches + self.reweight_classes, to='train')
self.class_weights = opts['weights'].get('class_weights', None)
if self.class_weights is None:
self.class_weights = np.ones(len(self.reweight_classes))
Expand Down Expand Up @@ -167,44 +173,56 @@ def _log(msg, *args, **kwargs):
'reweight_discard_under_overflow']:
_log('%s: %s' % (k, getattr(self, k)))

# parse config
self.keep_branches = set()
aux_branches = set()
# selection
if self.selection:
aux_branches.update(_get_variable_names(self.selection))
self.register(_get_variable_names(self.selection), to='train')
# test time selection
if self.test_time_selection:
aux_branches.update(_get_variable_names(self.test_time_selection))
# var_funcs
self.keep_branches.update(self.var_funcs.keys())
for expr in self.var_funcs.values():
aux_branches.update(_get_variable_names(expr))
self.register(_get_variable_names(self.test_time_selection), to='test')
# inputs
for names in self.input_dicts.values():
self.keep_branches.update(names)
# labels
self.keep_branches.update(self.label_names)
# weight
if self.weight_name:
self.keep_branches.add(self.weight_name)
if not self.use_precomputed_weights:
aux_branches.update(self.reweight_branches)
aux_branches.update(self.reweight_classes)
self.register(names)
# observers
self.keep_branches.update(self.observer_names)
self.register(self.observer_names, to='test')
# monitor variables
self.keep_branches.update(self.monitor_variables)
# keep and drop
self.drop_branches = (aux_branches - self.keep_branches)
self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys()) - {self.weight_name, }
self.register(self.monitor_variables)
# resolve dependencies
func_vars = set(self.var_funcs.keys())
for (load_branches, aux_branches) in (self.train_load_branches, self.train_aux_branches), (self.test_load_branches, self.test_aux_branches):
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self.var_funcs[k]))
if print_info:
_logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
_logger.debug('load_branches:\n %s', ','.join(self.load_branches))
_logger.debug('train_load_branches:\n %s', ', '.join(sorted(self.train_load_branches)))
_logger.debug('train_aux_branches:\n %s', ', '.join(sorted(self.train_aux_branches)))
_logger.debug('test_load_branches:\n %s', ', '.join(sorted(self.test_load_branches)))
_logger.debug('test_aux_branches:\n %s', ', '.join(sorted(self.test_aux_branches)))

def __getattr__(self, name):
return self.options[name]

def register(self, name, expr=None, to='both'):
assert to in ('train', 'test', 'both')
if isinstance(name, dict):
for k, v in name.items():
self.register(k, v, to)
elif isinstance(name, (list, tuple)):
for k in name:
self.register(k, None, to)
else:
if to in ('train', 'both'):
self.train_load_branches.add(name)
if to in ('test', 'both'):
self.test_load_branches.add(name)
if expr:
self.var_funcs[name] = expr
if to in ('train', 'both'):
self.train_aux_branches.add(name)
if to in ('test', 'both'):
self.test_aux_branches.add(name)

def dump(self, fp):
with open(fp, 'w') as f:
yaml.safe_dump(self.options, f, sort_keys=False)
Expand Down
89 changes: 48 additions & 41 deletions weaver/utils/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from .fileio import _read_files


def _apply_selection(table, selection, funcs={}):
def _apply_selection(table, selection, funcs=None):
if selection is None:
return table
new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs}
_build_new_variables(table, new_vars)
if funcs:
new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs}
_build_new_variables(table, new_vars)
selected = ak.values_astype(_eval_expr(selection, table), 'bool')
return table[selected]

Expand All @@ -28,11 +29,6 @@ def _build_new_variables(table, funcs):
return table


def _clean_up(table, drop_branches):
columns = [k for k in table.fields if k not in drop_branches]
return table[columns]


def _build_weights(table, data_config, reweight_hists=None):
if data_config.weight_name is None:
raise RuntimeError('Error when building weights: `weight_name` is None!')
Expand Down Expand Up @@ -92,27 +88,33 @@ def __init__(self, filelist, data_config):
self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

def read_file(self, filelist):
self.keep_branches = set()
self.load_branches = set()
keep_branches = set()
aux_branches = set()
load_branches = set()
for k, params in self._data_config.preprocess_params.items():
if params['center'] == 'auto':
self.keep_branches.add(k)
if k in self._data_config.var_funcs:
expr = self._data_config.var_funcs[k]
self.load_branches.update(_get_variable_names(expr))
else:
self.load_branches.add(k)
keep_branches.add(k)
load_branches.add(k)
if self._data_config.selection:
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, self.load_range, show_progressbar=True,
load_branches.update(_get_variable_names(self._data_config.selection))

func_vars = set(self._data_config.var_funcs.keys())
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self._data_config.var_funcs[k]))

_logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(keep_branches))
_logger.debug('[AutoStandardizer] aux_branches:\n %s', ','.join(aux_branches))
_logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(load_branches))

table = _read_files(filelist, load_branches, self.load_range, show_progressbar=True,
treename=self._data_config.treename,
branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic)
table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
table = _clean_up(table, self.load_branches - self.keep_branches)
table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches})
table = table[keep_branches]
return table

def make_preprocess_params(self, table):
Expand Down Expand Up @@ -142,7 +144,7 @@ def produce(self, output=None):
table = self.read_file(self._filelist)
preprocess_params = self.make_preprocess_params(table)
self._data_config.preprocess_params = preprocess_params
# must also propogate the changes to `data_config.options` so it can be persisted
# must also propagate the changes to `data_config.options` so it can be persisted
self._data_config.options['preprocess']['params'] = preprocess_params
if output:
_logger.info(
Expand All @@ -168,26 +170,31 @@ def __init__(self, filelist, data_config):
self._data_config = data_config.copy()

def read_file(self, filelist):
self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes +
(self._data_config.basewgt_name,))
self.load_branches = set()
for k in self.keep_branches:
if k in self._data_config.var_funcs:
expr = self._data_config.var_funcs[k]
self.load_branches.update(_get_variable_names(expr))
else:
self.load_branches.add(k)
keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes)
if self._data_config.reweight_basewgt:
keep_branches.add(self._data_config.basewgt_name)
aux_branches = set()
load_branches = keep_branches.copy()
if self._data_config.selection:
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, show_progressbar=True,
load_branches.update(_get_variable_names(self._data_config.selection))

func_vars = set(self._data_config.var_funcs.keys())
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self._data_config.var_funcs[k]))

_logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(keep_branches))
_logger.debug('[WeightMaker] aux_branches:\n %s', ','.join(aux_branches))
_logger.debug('[WeightMaker] load_branches:\n %s', ','.join(load_branches))

table = _read_files(filelist, load_branches, show_progressbar=True,
treename=self._data_config.treename,
branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic)
table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
table = _clean_up(table, self.load_branches - self.keep_branches)
table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches})
table = table[keep_branches]
return table

def make_weights(self, table):
Expand Down Expand Up @@ -284,7 +291,7 @@ def produce(self, output=None):
table = self.read_file(self._filelist)
wgts = self.make_weights(table)
self._data_config.reweight_hists = wgts
# must also propogate the changes to `data_config.options` so it can be persisted
# must also propagate the changes to `data_config.options` so it can be persisted
self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
if output:
_logger.info('Writing YAML file w/ reweighting info to %s' % output)
Expand Down
6 changes: 4 additions & 2 deletions weaver/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def _preprocess(table, data_config, options):
if len(table) == 0:
return []
# define new variables
table = _build_new_variables(table, data_config.var_funcs)
aux_branches = data_config.train_aux_branches if options['training'] else data_config.test_aux_branches
table = _build_new_variables(table, {k: v for k, v in data_config.var_funcs.items() if k in aux_branches})
# check labels
if data_config.label_type == 'simple' and options['training']:
_check_labels(table)
Expand All @@ -108,7 +109,8 @@ def _preprocess(table, data_config, options):


def _load_next(data_config, filelist, load_range, options):
table = _read_files(filelist, data_config.load_branches, load_range, treename=data_config.treename,
load_branches = data_config.train_load_branches if options['training'] else data_config.test_load_branches
table = _read_files(filelist, load_branches, load_range, treename=data_config.treename,
branch_magic=data_config.branch_magic, file_magic=data_config.file_magic)
table, indices = _preprocess(table, data_config, options)
return table, indices
Expand Down

0 comments on commit 80c1fd0

Please sign in to comment.