xgb_models.py
"""Utilities for handling XGBoost model files within the MS²PIP prediction framework."""
import hashlib
import logging
import os
import urllib.request
from itertools import islice
import numpy as np
import xgboost as xgb
from ms2pip.exceptions import InvalidXGBoostModelError
logger = logging.getLogger(__name__)


def validate_requested_xgb_model(xgboost_model_files, xgboost_model_hashes, model_dir):
    """Validate the requested XGBoost model files, downloading them if necessary."""
    for model_file in xgboost_model_files.values():
        if not _check_model_presence(model_file, xgboost_model_hashes[model_file], model_dir):
            _download_model(model_file, xgboost_model_hashes[model_file], model_dir)
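
# A minimal sketch of the expected arguments; the file names, hash values, and
# directory below are hypothetical placeholders, not real model metadata:
#
#     validate_requested_xgb_model(
#         xgboost_model_files={"b": "model_b.xgboost", "y": "model_y.xgboost"},
#         xgboost_model_hashes={
#             "model_b.xgboost": "0123456789abcdef...",  # SHA-1 hex digest
#             "model_y.xgboost": "fedcba9876543210...",
#         },
#         model_dir="/path/to/model_dir",
#     )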


def get_predictions_xgb(features, num_ions, model_params, model_dir, processes=1):
    """
    Get predictions starting from feature vectors in a DMatrix object.

    Parameters
    ----------
    features: xgb.DMatrix, np.ndarray
        Feature vectors as a DMatrix object or as a NumPy array.
    num_ions: list[int]
        Number of ions (per series) for each peptide, i.e. peptide length - 1.
    model_params: dict
        Model configuration as defined in ms2pip.ms2pipC.MODELS.
    model_dir: str
        Directory where model files are stored.
    processes: int
        Number of CPUs to use in multiprocessing.
    """
    # Initialize models
    xgboost_models = _initialize_xgb_models(
        model_params["xgboost_model_files"],
        model_dir,
        processes,
    )

    if isinstance(features, np.ndarray):
        features = xgb.DMatrix(features)

    logger.debug("Predicting intensities from XGBoost model files...")
    prediction_dict = {}
    for ion_type, xgb_model in xgboost_models.items():
        # Get predictions from the XGBoost model
        preds = xgb_model.predict(features)
        preds = preds.clip(min=np.log2(0.001))  # Floor very low log2 intensities
        # Explicitly free the booster's native memory; the models dict still
        # holds a reference, so garbage collection alone would keep it alive
        xgb_model.__del__()

        # Reshape flat predictions into one array per peptide; C-terminal ion
        # series are reversed so intensities run from the N-terminus
        if ion_type.lower() in ["x", "y", "y2", "z"]:
            preds = _split_list_by_lengths(preds, num_ions, reverse=True)
        elif ion_type.lower() in ["a", "b", "b2", "c"]:
            preds = _split_list_by_lengths(preds, num_ions, reverse=False)
        else:
            raise ValueError(f"Unsupported ion_type: {ion_type}")
        prediction_dict[ion_type] = preds

    # Convert to a list with, per peptide, a dict of arrays per ion type
    num_peptides = len(next(iter(prediction_dict.values())))
    predictions = [{k: v[i] for k, v in prediction_dict.items()} for i in range(num_peptides)]
    return predictions
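
# For two peptides and b/y models, `get_predictions_xgb` returns a structure
# like the following (array contents are illustrative, not real output):
#
#     [
#         {"b": array([...], dtype=float32), "y": array([...], dtype=float32)},
#         {"b": array([...], dtype=float32), "y": array([...], dtype=float32)},
#     ]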


def _split_list_by_lengths(list_in, lengths, reverse=False):
    """Split a flat list of predictions into one array per peptide, given peptide lengths."""
    list_in = iter(list_in)
    if reverse:
        list_out = [np.array(list(islice(list_in, e)), dtype=np.float32)[::-1] for e in lengths]
    else:
        list_out = [np.array(list(islice(list_in, e)), dtype=np.float32) for e in lengths]
    return list_out
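
# Illustration of the splitting behavior (the input values are made up):
#
#     >>> _split_list_by_lengths([1, 2, 3, 4, 5], lengths=[2, 3], reverse=True)
#     [array([2., 1.], dtype=float32), array([5., 4., 3.], dtype=float32)]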


def _check_model_presence(model, model_hash, model_dir):
    """Check whether the XGBoost model file is present and intact."""
    filename = os.path.join(model_dir, model)
    if not os.path.isfile(filename):
        return False
    return _check_model_integrity(filename, model_hash)


def _download_model(model, model_hash, model_dir):
    """Download the XGBoost model file from the Genesis server."""
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.join(model_dir, model)
    logger.info(f"Downloading {model} to {filename}...")
    # Build the URL with string concatenation; os.path.join would insert
    # backslashes on Windows
    urllib.request.urlretrieve(
        "http://genesis.ugent.be/uvpublicdata/ms2pip/" + model, filename
    )
    if not _check_model_integrity(filename, model_hash):
        raise InvalidXGBoostModelError()


def _check_model_integrity(filename, model_hash):
    """Check that the downloaded model file matches its expected SHA-1 hash."""
    sha1_hash = hashlib.sha1()
    with open(filename, "rb") as model_file:
        while True:
            chunk = model_file.read(16 * 1024)
            if not chunk:
                break
            sha1_hash.update(chunk)
    if sha1_hash.hexdigest() == model_hash:
        return True
    else:
        logger.warning("Model hash not recognized.")
        return False


def _initialize_xgb_models(xgboost_model_files, model_dir, nthread) -> dict:
    """Initialize XGBoost models and return them in a dict with ion types as keys."""
    xgb.set_config(verbosity=0)
    xgboost_models = {}
    for ion_type, model_filename in xgboost_model_files.items():
        model_file = os.path.join(model_dir, model_filename)
        logger.debug(f"Initializing model from file: `{model_file}`")
        xgb_model = xgb.Booster({"nthread": nthread})
        xgb_model.load_model(model_file)
        xgboost_models[ion_type] = xgb_model
    return xgboost_models
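
# A minimal end-to-end sketch of how these utilities fit together. The model
# name, MODELS entry layout, and feature array shape are assumptions for
# illustration; consult ms2pip.ms2pipC.MODELS for the real configuration.
#
#     from ms2pip.ms2pipC import MODELS
#
#     model_params = MODELS["HCD"]  # hypothetical model name
#     model_dir = "/path/to/model_dir"
#
#     # Download and verify model files if they are not yet present
#     validate_requested_xgb_model(
#         model_params["xgboost_model_files"],
#         model_params["xgboost_model_hashes"],
#         model_dir,
#     )
#
#     # `features` holds one row per predicted ion; `num_ions` holds
#     # peptide_length - 1 for each peptide
#     features = np.zeros((8, 188), dtype=np.float32)  # made-up shape
#     predictions = get_predictions_xgb(
#         features, num_ions=[3, 5], model_params=model_params, model_dir=model_dir
#     )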