Skip to content

Commit

Permalink
Merge pull request #19 from IBM/code-dev
Browse files Browse the repository at this point in the history
Code dev
  • Loading branch information
RaulFD-creator authored Aug 15, 2024
2 parents 580aa36 + 0dc6069 commit 4b94895
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
17 changes: 11 additions & 6 deletions autopeptideml/autopeptideml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import optuna
import pandas as pd
import scikitplot as skplt
import scikitplot.metrics as skplt
import sklearn.metrics
from sklearn.model_selection import StratifiedKFold

Expand Down Expand Up @@ -648,17 +648,22 @@ def _make_figures(self, figures_path: str, truths, preds_proba):
new_preds_proba[:, 0] = 1 - preds_proba
new_preds_proba[:, 1] = preds_proba
preds_proba = new_preds_proba
skplt.metrics.plot_confusion_matrix(truths, preds, normalize=False,
title='Confusion Matrix')
skplt.plot_confusion_matrix(truths, preds, normalize=False,
title='Confusion Matrix')
plt.savefig(os.path.join(figures_path, 'confusion_matrix.png'))
plt.close()
skplt.metrics.plot_roc(truths, preds_proba, title='ROC Curve', plot_micro=False, plot_macro=False, classes_to_plot=[1])
skplt.plot_roc(truths, preds_proba, title='ROC Curve',
plot_micro=False, plot_macro=False,
classes_to_plot=[1])
plt.savefig(os.path.join(figures_path, 'roc_curve.png'))
plt.close()
skplt.metrics.plot_precision_recall(truths, preds_proba, title='Precision-Recall Curve', plot_micro=False, classes_to_plot=[1])
skplt.plot_precision_recall(truths, preds_proba,
title='Precision-Recall Curve',
plot_micro=False, classes_to_plot=[1])
plt.savefig(os.path.join(figures_path, 'precision_recall_curve.png'))
plt.close()
skplt.metrics.plot_calibration_curve(truths, [preds_proba], title='Calibration Curve')
skplt.plot_calibration_curve(truths, [preds_proba],
title='Calibration Curve')
plt.savefig(os.path.join(figures_path, 'calibration_curve.png'))
plt.close()

Expand Down
3 changes: 1 addition & 2 deletions autopeptideml/utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def compute_batch(self, batch: list, average_pooling: bool):
decoder_input_ids=inputs['input_ids']
).last_hidden_state
else:
embd_rpr = self.model(**inputs)

embd_rpr = self.model(**inputs).last_hidden_state
output = []
for idx in range(len(batch)):
seq_len = len(batch[idx])
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@
name='autopeptideml',
packages=find_packages(exclude=['examples']),
url='https://ibm.github.io/AutoPeptideML/',
version='0.3.3',
version='0.3.4',
zip_safe=False,
)
27 changes: 27 additions & 0 deletions tests/test_representationengine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from autopeptideml import RepresentationEngine
import numpy as np


def test_esm_family():
    """Check embedding width and output shapes for the 'esm2-8m' engine.

    Three representations are computed before any assertion runs: a single
    pooled sequence, a pooled pair of sequences, and a single sequence
    without average pooling.
    """
    engine = RepresentationEngine('esm2-8m', batch_size=12)
    pooled_single = engine.compute_representations(
        ['AACFFF'], average_pooling=True
    )
    pooled_pair = engine.compute_representations(
        ['AACFFF', 'AACCF'], average_pooling=True
    )
    unpooled_single = engine.compute_representations(
        ['AACFFF'], average_pooling=False
    )

    # The engine must report an embedding width of 320.
    assert engine.dim() == 320

    # Pooled outputs: one 320-wide vector per input sequence.
    pooled_single_shape = np.array(pooled_single).shape
    assert pooled_single_shape == (1, 320)
    pooled_pair_shape = np.array(pooled_pair).shape
    assert pooled_pair_shape == (2, 320)

    # Without pooling, the 6-residue input keeps a length-6 axis.
    unpooled_shape = np.array(unpooled_single).shape
    assert unpooled_shape == (1, 6, 320)


def test_elnaggar_family():
    """Check embedding width and pooled output shape for 'ankh-base'."""
    engine = RepresentationEngine('ankh-base', batch_size=12)
    pooled = engine.compute_representations(
        ['AACFFF'], average_pooling=True
    )

    # The engine must report an embedding width of 768.
    assert engine.dim() == 768

    # One input sequence -> one vector as wide as the reported dimension.
    pooled_shape = np.array(pooled).shape
    assert pooled_shape == (1, engine.dim())


def test_rostlab_family():
    """Check embedding width and pooled output shape for 'prot-t5-xl'."""
    engine = RepresentationEngine('prot-t5-xl', batch_size=12)
    pooled = engine.compute_representations(
        ['AACFFF'], average_pooling=True
    )

    # The engine must report an embedding width of 1024.
    assert engine.dim() == 1024

    # One input sequence -> one vector as wide as the reported dimension.
    pooled_shape = np.array(pooled).shape
    assert pooled_shape == (1, engine.dim())

0 comments on commit 4b94895

Please sign in to comment.