From fceb684982598f8dc926419300c7ae5d69b618a8 Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 20 Dec 2023 23:15:09 -0500 Subject: [PATCH 1/3] add example for representation --- .../extracting-representation-molfeat.ipynb | 336 ++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 337 insertions(+) create mode 100644 docs/tutorials/extracting-representation-molfeat.ipynb diff --git a/docs/tutorials/extracting-representation-molfeat.ipynb b/docs/tutorials/extracting-representation-molfeat.ipynb new file mode 100644 index 0000000..43e1b7d --- /dev/null +++ b/docs/tutorials/extracting-representation-molfeat.ipynb @@ -0,0 +1,336 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import safe\n", + "import torch\n", + "import datamol as dm\n", + "import types\n", + "from molfeat.trans.pretrained import PretrainedMolTransformer\n", + "from molfeat.trans.pretrained import PretrainedHFTransformer\n", + "\n", + "from molfeat.trans.pretrained.hf_transformers import HFModel\n", + "from safe.trainer.model import SAFEDoubleHeadsModel\n", + "from safe.tokenizer import SAFETokenizer\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the SAFE-GPT model into molfeat\n", + "\n", + "Because the SAFE model is not a standard HuggingFace `transformers` model, we need to wrap it.\n", + "\n", + "Why are we doing this? Because we want to leverage the structure of `molfeat` and not have to write our own pooling for the model. This can be done by using the huggingface molecule transformer `PretrainedHFTransformer` rather than the general purpose pretrained model class `PretrainedMolTransformer`, where we would have to define our own `_embed` and `_convert` functions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "safe_model = SAFEDoubleHeadsModel.from_pretrained(\"datamol-io/safe-gpt\")\n", + "safe_tokenizer = SAFETokenizer.from_pretrained(\"datamol-io/safe-gpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now need to build the `molfeat`'s `HFModel` instance by wrapping our model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "safe_hf_model = HFModel.from_pretrained(safe_model, safe_tokenizer.get_pretrained())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can put the above process in the `__init__` of the `SAFEMolTransformer` if you wish as we will be doing below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Building the SAFE Molecule Transformers\n", + "\n", + "We have multiple options here, we can override the `_convert` method or even the `_embed` method but the best thing about `molfeat` is how flexible it is and all the shortcuts it provides. 
\n", + "\n", + "In this case, we just need to change the custom converter,\n", + "i.e. how the input molecules are converted into the SAFE format." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-20 22:57:39.310\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mmolfeat.trans.base\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m51\u001b[0m - \u001b[33m\u001b[1mThe 'SAFEMolTransformer' interaction has been superseded by a new class with id 0x2ad77d6a0\u001b[0m\n" + ] + } + ], + "source": [ + "class SAFEMolTransformer(PretrainedHFTransformer):\n", + " \"\"\"Build the SAFE Molecule transformers, the only thing we need to define is \n", + " how we convert the input molecules into the safe format\"\"\"\n", + " def __init__(self, kind=None, notation=\"safe\", **kwargs):\n", + " if kind is None:\n", + " # we load the default SAFE model if the exact SAFE GPT model \n", + " # to use is not provided\n", + " safe_model = SAFEDoubleHeadsModel.from_pretrained(\"datamol-io/safe-gpt\")\n", + " safe_tokenizer = SAFETokenizer.from_pretrained(\"datamol-io/safe-gpt\")\n", + " kind = HFModel.from_pretrained(safe_model, safe_tokenizer.get_pretrained())\n", + " super().__init__(kind, notation=None, **kwargs)\n", + " # now we change the internal converter\n", + " # overriding the internal converter of SmilesConverter leverages the exception handling\n", + " # The SAFE-GPT model was trained on a slightly different splitting algorithm compared to the default BRICS\n", + " # this does not change anything in theory, it just tries harder to break bonds even if there are no BRICS bonds.\n", + " self.converter.converter = types.SimpleNamespace(decode=safe.decode, encode=safe.utils.convert_to_safe)\n", + " # you could also do any of the following:\n", + " # self.converter = types.SimpleNamespace(decode=safe.decode, encode=safe.encode)\n", + " # self.converter = safe # the safe module\n" + ] 
+ }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's use the GPT pooler which uses the last non padding token (often `eos`) since the model is GPT2 like. For other options, see: https://molfeat-docs.datamol.io/stable/api/molfeat.utils.html#pooling" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SAFEMolTransformer(dtype=np.float32)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "SAFEMolTransformer(dtype=np.float32)" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Let's use the GPT pooling method\n", + "safe_transformers = SAFEMolTransformer(pooling=\"gpt\")\n", + "safe_transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [], + "source": [ + "mols = dm.data.freesolv().iloc[:10].smiles.values" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0.05216356, 0.10754181, 0.07509107, ..., 0.04756968,\n", + " -0.08228929, -0.11568106],\n", + " [ 0.02449008, 0.04048932, 0.14489463, ..., 0.11410899,\n", + " -0.02203353, 0.08706839],\n", + " [-0.07425696, 0.11859665, 0.19010407, ..., 0.10526019,\n", + " 0.08878426, -0.06609854],\n", + " ...,\n", + " [ 0.07867863, 0.19300285, 0.23054805, ..., -0.00737952,\n", + " 0.07542405, 0.00289541],\n", + " [ 0.12092628, -0.01785688, 0.19791883, ..., 0.13796932,\n", + " 0.11520796, -0.15333697],\n", + " [-0.02005584, 0.13946685, 0.18568742, ..., 0.07080407,\n", + " 0.06991849, -0.07151204]], dtype=float32)" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "safe_transformers(mols)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Basic Test" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestRegressor\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.pipeline import Pipeline\n", + "\n", + "df = dm.data.freesolv()\n", + "df[\"safe\"] = df[\"smiles\"].apply(safe_transformers.converter.encode)\n", + "df = df.dropna(subset=\"safe\")\n", + "# we have to remove the molecules that cannot be converted \n", + "# (no breakable bonds with our default methodology)\n" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "X, y = df[\"smiles\"].values, df[\"expt\"].values\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=25, test_size=0.2)\n", + "\n", + "# The Molfeat transformer seamlessly integrates with Scikit-learn Pipeline!\n", + "pipe = Pipeline([(\"feat\", safe_transformers), (\"rf\", RandomForestRegressor())])" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [], + "source": [ + "with dm.without_rdkit_log():\n", + " pipe.fit(X_train, y_train)\n", + " score = pipe.score(X_test, y_test)\n", + " y_pred = pipe.predict(X_test)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "R2 score 0.5082630204054333\n" + ] + } + ], + "source": [ + "print(\"R2 score\", score)" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'Preds')" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.scatter(y_test, y_pred)\n", + "ax.set_xlabel(\"Target\")\n", + "ax.set_ylabel(\"Preds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Not really a great result. Any other model in `molfeat` would do better." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tips\n", + "\n", + "1. Make sure that your inputs are SMILES or RDKit Molecules.\n", + "2. If you are getting an error coming from some tokenization step, that means that you are likely getting `None` molecules at some step in the conversion to SAFE. This can happen if your slicing algorithm of choice is not working. In that case, please filter your datasets to remove the molecules that fail the encoding step first. You can always use the very robust `safe.utils.convert_to_safe`, which augments the default BRICS slicing with a graph partitioning algorithm.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "safe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mkdocs.yml b/mkdocs.yml index d302e8d..379083b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -16,6 +16,7 @@ nav: - Getting Started: tutorials/getting-started.ipynb - Molecular design: tutorials/design-with-safe.ipynb - How it works: tutorials/how-it-works.ipynb + - Extracting representation (molfeat): tutorials/extracting-representation-molfeat.ipynb - API: - SAFE: api/safe.md - Visualization: api/safe.viz.md From 529003341ea75a50c172a9592307b1cc2f46b595 Mon Sep 17 00:00:00 2001 From: maclandrol Date: 
Wed, 20 Dec 2023 23:23:11 -0500 Subject: [PATCH 2/3] wip --- .../extracting-representation-molfeat.ipynb | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/tutorials/extracting-representation-molfeat.ipynb b/docs/tutorials/extracting-representation-molfeat.ipynb index 43e1b7d..17977c2 100644 --- a/docs/tutorials/extracting-representation-molfeat.ipynb +++ b/docs/tutorials/extracting-representation-molfeat.ipynb @@ -128,32 +128,32 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
SAFEMolTransformer(dtype=np.float32)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + "
SAFEMolTransformer(dtype=np.float32)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "SAFEMolTransformer(dtype=np.float32)" ] }, - "execution_count": 98, + "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Let's use the GPT pooling method\n", - "safe_transformers = SAFEMolTransformer(pooling=\"gpt\")\n", + "# Let's use the GPT pooling method and only take the last hidden layer\n", + "safe_transformers = SAFEMolTransformer(pooling=\"gpt\", concat_layers=[-1])\n", "safe_transformers" ] }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 118, "metadata": {}, "outputs": [ { @@ -183,7 +183,7 @@ " 0.06991849, -0.07151204]], dtype=float32)" ] }, - "execution_count": 100, + "execution_count": 118, "metadata": {}, "output_type": "execute_result" } @@ -201,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -218,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -232,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -244,24 +244,24 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 122, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "R2 score 0.5082630204054333\n" + "R2 score: 0.4971483821661925\n" ] } ], "source": [ - "print(\"R2 score\", score)" + "print(\"R2 score:\", score)" ] }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 123, "metadata": {}, "outputs": [ { @@ -270,13 +270,13 @@ "Text(0, 0.5, 'Preds')" ] }, - "execution_count": 105, + "execution_count": 123, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] From 527dbf633ccdafd6896857ae5ebf1ddd67bda0bf Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 20 Dec 2023 23:29:41 -0500 Subject: [PATCH 3/3] do not run the features notebook --- tests/test_notebooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index e4e4d9d..23a0ab5 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -13,7 +13,7 @@ NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS)) # Discard some notebooks -NOTEBOOKS_TO_DISCARD = ["Basic_Concepts.ipynb"] +NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb"] NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS))