Skip to content

Commit

Permalink
Merge pull request #42 from IBM/code-dev
Browse files: browse the repository at this point in the history
🛠️🧹: Minor patches and code cleanliness improvements
  • Loading branch information
RaulFD-creator authored Jan 3, 2025
2 parents c7f07df + 3268f5a commit fc2ae8f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 29 deletions.
82 changes: 59 additions & 23 deletions autopeptideml/autopeptideml.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,34 +727,70 @@ def save_models(
self,
best_model: list,
outputdir: str,
id2rep: dict
id2rep: dict,
backend: str = 'onnx'
):
from skl2onnx import to_onnx
from skl2onnx.common.data_types import FloatTensorType
import onnxmltools as onxt
if backend == 'joblib':
try:
import joblib
except ImportError:
raise ImportError(
'This backend requires joblib.',
'Please try: `pip install joblib`'
)
elif backend == 'onnx':
try:
import onnxmltools as onxt
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import to_onnx
except ImportError:
raise ImportError(
'This backend requires onnx.',
'Please try: `pip install onnxmltools skl2onnx`'
)

else:
raise NotImplementedError(f"Backend: {backend} not implemented.",
"Please try: `onnx` or `joblib`.")

raw_data_path = osp.join(outputdir, 'ensemble')
os.makedirs(raw_data_path, exist_ok=True)
if isinstance(id2rep, dict):
X = np.array(list(id2rep.values())[:5])
else:
X = id2rep

variable_type = FloatTensorType([None, X.shape[1]])
for idx, clf in enumerate(best_model['estimators']):
if 'LGBM' in str(clf):
clf_onx = onxt.convert_lightgbm(
clf,
initial_types=[('float_input', variable_type)]
)
elif 'XGB' in str(clf):
clf_onx = onxt.convert_xgboost(
clf,
initial_types=[('float_input', variable_type)]
)
if backend == 'onnx':
if isinstance(id2rep, dict):
X = np.array(list(id2rep.values())[:5])
else:
clf_onx = to_onnx(clf, X)
with open(osp.join(raw_data_path, f"{idx}.onnx"), "wb") as f:
f.write(clf_onx.SerializeToString())
X = id2rep

variable_type = FloatTensorType([None, X.shape[1]])
for idx, clf in enumerate(best_model['estimators']):
if 'LGBM' in str(clf):
clf_onx = onxt.convert_lightgbm(
clf,
initial_types=[('float_input', variable_type)]
)
elif 'XGB' in str(clf):
clf_onx = onxt.convert_xgboost(
clf,
initial_types=[('float_input', variable_type)]
)
else:
clf_onx = to_onnx(clf, X)

if 'class' in str(clf).lower():
name = f'{idx}_class.onnx'
else:
name = f'{idx}_reg.onnx'
with open(osp.join(raw_data_path, name), "wb") as f:

f.write(clf_onx.SerializeToString())
else:
for idx, clf in enumerate(best_model['estimators']):
if 'class' in str(clf).lower():
name = f'{idx}_class.onnx'
else:
name = f'{idx}_reg.onnx'
joblib.dump(clf, open(osp.join(raw_data_path, name)), 'wb')

def _onnx_prediction(
self,
Expand Down
15 changes: 9 additions & 6 deletions autopeptideml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def build_model(
apml.save_models(
best_model=model,
outputdir=osp.join(outputdir, 'ensemble'),
id2rep=id2rep
id2rep=id2rep,
backend=model_save_backend
)
if verbose is True:
print(results)
Expand All @@ -215,7 +216,8 @@ def predict(
threads: int = typer.Option(cpu_count(), help="Number of threads to use."),
plm: str = typer.Option("esm2-8m", help="PLM for computing peptide representations. Check GitHub Repository for available options."),
plm_batch_size: int = typer.Option(12, help="Batch size for PLM."),
device: str = typer.Option(None, help="Device where the representations should be computed.")
plm_device: str = typer.Option(None, help="Device where the representations should be computed."),
model_save_backend: str = typer.Option("onnx", help="Backend for storing models. Options: `onnx` or `joblib`"),
) -> pd.DataFrame:
"""
Predicts peptide representations and outputs predictions using a pre-trained Peptide Language Model (PLM).
Expand All @@ -226,7 +228,7 @@ def predict(
----------
dataset : str
Path to the dataset to be processed. The dataset should be in a format compatible with APML.
ensemble : str, optional
Path to a directory containing previous APML results for ensemble predictions. If `None`, no ensemble is used.
Default is `None`.
Expand Down Expand Up @@ -279,11 +281,12 @@ def predict(
)
"""
re = RepresentationEngine(plm, plm_batch_size)
if device is not None:
re.move_to_device(device)
if plm_device is not None:
re.move_to_device(plm_device)
apml = AutoPeptideML(verbose, threads, 1)
df = apml.curate_dataset(dataset, outputdir)
return apml.predict(df, re, ensemble, outputdir)
return apml.predict(df, re, ensemble, outputdir,
backend=model_save_backend)


def _build_model():
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
'lightgbm',
'xgboost',
'mdpdf',
'onnxmltools',
'skl2onnx',
'onnxruntime',
'hestia-ood>=0.0.34'
]

Expand Down

0 comments on commit fc2ae8f

Please sign in to comment.