diff --git a/notebooks/FineMappingSimulations.ipynb b/notebooks/FineMappingSimulations.ipynb
new file mode 100644
index 000000000..d81b268e3
--- /dev/null
+++ b/notebooks/FineMappingSimulations.ipynb
@@ -0,0 +1,425 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Simulations to benchmark the fine-mapping"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The simulations are based on the specific GnomAD LD matrix from the 1Mb region on chromosome 7.\n",
+ "\n",
+ "At each iteration of the simulation we randomly select n_causal causal variants and generate Z-scores. We then perform fine mapping using GentroPy functions and examine the output.\n",
+ "\n",
+ "We expect all selected variants to be presented in detected credible sets."
+ ]
+ },
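+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A sketch of the generative model (as implemented in `FineMappingSimulations.SimSumStatFromLD`): each causal variant receives a joint Z-score drawn from $N(0, \\tau)$ with $\\tau = n \\cdot h^2_{reggen} / n_{causal}$, and the marginal Z-scores for all variants are then drawn from $N(U z_{joint}, U)$, where $U$ is the LD matrix."
+ ]
+ },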
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24/05/21 18:05:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "from gentropy.common.session import Session\n",
+ "from gentropy.finemapping_simulations import FineMappingSimulations\n",
+ "\n",
+ "session = Session()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ld_matrix = np.load('/Users/yt4/Projects/ot_data/tmp/ld_matrix.npy')\n",
+ "ld_index=session.spark.read.parquet(\"/Users/yt4/Projects/ot_data/tmp/ld_index\")\n",
+ "ld_matrix_for_sim=ld_matrix[0:500,:][:,0:500]\n",
+ "ld_index_for_sim=ld_index.limit(500)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Number of causal variants = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=1\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.003,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim\n",
+ ");"
+ ]
+ },
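+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here and in the analogous cells below we print the summary three times: for all detected credible sets, after keeping only credible sets with `pValueExponent <= -6` (roughly $p \\leq 10^{-6}$) or `credibleSetIndex == 1`, and after additionally requiring `purityMinR2 >= 0.25` or `credibleSetIndex == 1`."
+ ]
+ },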
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 76, 'number_of_cs': 76, 'expected_results': 76, 'false_positives': 0.013157894736842105, 'accuracy': 0.9868421052631579, 'accuracy_lead': 0.6447368421052632, 'sensitivity': 0.9868421052631579}\n",
+ "{'successful_runs': 76, 'number_of_cs': 76, 'expected_results': 76, 'false_positives': 0.013157894736842105, 'accuracy': 0.9868421052631579, 'accuracy_lead': 0.6447368421052632, 'sensitivity': 0.9868421052631579}\n",
+ "{'successful_runs': 76, 'number_of_cs': 76, 'expected_results': 76, 'false_positives': 0.013157894736842105, 'accuracy': 0.9868421052631579, 'accuracy_lead': 0.6447368421052632, 'sensitivity': 0.9868421052631579}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Number of causal variants = 3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=3\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.003,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 89, 'number_of_cs': 177, 'expected_results': 267, 'false_positives': 0.062146892655367235, 'accuracy': 0.9378531073446328, 'accuracy_lead': 0.6666666666666666, 'sensitivity': 0.6217228464419475}\n",
+ "{'successful_runs': 89, 'number_of_cs': 172, 'expected_results': 267, 'false_positives': 0.05232558139534884, 'accuracy': 0.9476744186046512, 'accuracy_lead': 0.6802325581395349, 'sensitivity': 0.6104868913857678}\n",
+ "{'successful_runs': 89, 'number_of_cs': 161, 'expected_results': 267, 'false_positives': 0.049689440993788817, 'accuracy': 0.9503105590062112, 'accuracy_lead': 0.6832298136645962, 'sensitivity': 0.5730337078651685}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## CARMA without noise"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=1\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.003,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim,\n",
+ " noise=False,\n",
+ " run_carma=True\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 74, 'number_of_cs': 74, 'expected_results': 74, 'false_positives': 0.04054054054054054, 'accuracy': 0.9594594594594594, 'accuracy_lead': 0.7027027027027027, 'sensitivity': 0.9594594594594594}\n",
+ "{'successful_runs': 74, 'number_of_cs': 74, 'expected_results': 74, 'false_positives': 0.04054054054054054, 'accuracy': 0.9594594594594594, 'accuracy_lead': 0.7027027027027027, 'sensitivity': 0.9594594594594594}\n",
+ "{'successful_runs': 74, 'number_of_cs': 74, 'expected_results': 74, 'false_positives': 0.04054054054054054, 'accuracy': 0.9594594594594594, 'accuracy_lead': 0.7027027027027027, 'sensitivity': 0.9594594594594594}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## No noise, but with CARMA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=3\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.003,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim,\n",
+ " noise=False,\n",
+ " run_carma=True\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 91, 'number_of_cs': 172, 'expected_results': 273, 'false_positives': 0.10465116279069768, 'accuracy': 0.8953488372093024, 'accuracy_lead': 0.6453488372093024, 'sensitivity': 0.5641025641025641}\n",
+ "{'successful_runs': 91, 'number_of_cs': 162, 'expected_results': 273, 'false_positives': 0.09259259259259259, 'accuracy': 0.9074074074074074, 'accuracy_lead': 0.6666666666666666, 'sensitivity': 0.5384615384615384}\n",
+ "{'successful_runs': 91, 'number_of_cs': 150, 'expected_results': 273, 'false_positives': 0.07333333333333333, 'accuracy': 0.9266666666666666, 'accuracy_lead': 0.6933333333333334, 'sensitivity': 0.5091575091575091}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Adding noise"
+ ]
+ },
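+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With `noise=True`, Gaussian noise (with standard deviation `scale_noise`) is added to the Z-scores of two variants in strong LD ($|r| > 0.5$) with a randomly chosen causal variant, intended to mimic a mismatch between the summary statistics and the LD reference."
+ ]
+ },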
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### One varaint, noise, no CARMA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=1\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.005,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim,\n",
+ " noise=True,\n",
+ " run_carma=False,\n",
+ " scale_noise=2,\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 76, 'number_of_cs': 115, 'expected_results': 76, 'false_positives': 0.48695652173913045, 'accuracy': 0.5130434782608696, 'accuracy_lead': 0.4, 'sensitivity': 0.7763157894736842}\n",
+ "{'successful_runs': 76, 'number_of_cs': 112, 'expected_results': 76, 'false_positives': 0.4732142857142857, 'accuracy': 0.5267857142857143, 'accuracy_lead': 0.4107142857142857, 'sensitivity': 0.7763157894736842}\n",
+ "{'successful_runs': 76, 'number_of_cs': 111, 'expected_results': 76, 'false_positives': 0.46846846846846846, 'accuracy': 0.5315315315315315, 'accuracy_lead': 0.4144144144144144, 'sensitivity': 0.7763157894736842}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### One varaint, noise and CARMA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_causal=1\n",
+ "x1=FineMappingSimulations.SimulationLoop(\n",
+ " n_iter=100,\n",
+ " n_causal=n_causal,\n",
+ " session=session,\n",
+ " he2_reggen=0.005,\n",
+ " sample_size=100_000,\n",
+ " ld_matrix_for_sim=ld_matrix_for_sim,\n",
+ " ld_index=ld_index_for_sim,\n",
+ " noise=True,\n",
+ " run_carma=True,\n",
+ " scale_noise=2,\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'successful_runs': 86, 'number_of_cs': 99, 'expected_results': 86, 'false_positives': 0.30303030303030304, 'accuracy': 0.696969696969697, 'accuracy_lead': 0.5353535353535354, 'sensitivity': 0.8023255813953488}\n",
+ "{'successful_runs': 86, 'number_of_cs': 95, 'expected_results': 86, 'false_positives': 0.2736842105263158, 'accuracy': 0.7263157894736842, 'accuracy_lead': 0.5578947368421052, 'sensitivity': 0.8023255813953488}\n",
+ "{'successful_runs': 86, 'number_of_cs': 93, 'expected_results': 86, 'false_positives': 0.26881720430107525, 'accuracy': 0.7311827956989247, 'accuracy_lead': 0.5698924731182796, 'sensitivity': 0.7906976744186046}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x1,n_causal=n_causal))\n",
+ "x2=x1[(x1[\"pValueExponent\"]<=-6) | (x1[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x2,n_causal=n_causal))\n",
+ "x3=x2[(x2[\"purityMinR2\"]>=0.25) | (x2[\"credibleSetIndex\"]==1)]\n",
+ "print(FineMappingSimulations.ProvideSummary(cred_sets=x3,n_causal=n_causal))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "gentropy-krNFZEZg-py3.10",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/gentropy/finemapping_simulations.py b/src/gentropy/finemapping_simulations.py
new file mode 100644
index 000000000..66a8f7855
--- /dev/null
+++ b/src/gentropy/finemapping_simulations.py
@@ -0,0 +1,260 @@
+"""Step/stash of functions to run a simulations to benchmark finemapping."""
+
+from typing import Any
+
+import numpy as np
+import pandas as pd
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import array_contains, col, when
+from pyspark.sql.types import DoubleType, StringType, StructField, StructType
+from scipy.stats import chi2
+
+from gentropy.common.session import Session
+from gentropy.susie_finemapper import SusieFineMapperStep
+
+
+class FineMappingSimulations:
+ """The module describes functions for running fine-mapping simulations and benchmarking."""
+
+ @staticmethod
+    def ProvideSummary(cred_sets: pd.DataFrame, n_causal: int) -> dict[str, Any]:
+ """Provides summary for the simulation results.
+
+ Args:
+            cred_sets (pd.DataFrame): Pandas DataFrame containing the credible sets.
+ n_causal (int): Number of causal SNPs.
+
+ Returns:
+ dict[str, Any]: Dictionary containing the summary.
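+                The keys are: ``successful_runs`` (simulations that produced at
+                least one credible set), ``number_of_cs`` (total number of credible
+                sets), ``expected_results`` (``n_causal`` times the number of
+                successful runs), ``false_positives`` / ``accuracy`` (share of
+                credible sets missing / containing a true causal variant),
+                ``accuracy_lead`` (share of credible sets whose lead variant is
+                causal) and ``sensitivity`` (share of expected causal hits
+                recovered).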
+ """
+ return {
+ "successful_runs": cred_sets["studyId"].nunique(),
+ "number_of_cs": len(cred_sets["is_in_X"]),
+ "expected_results": n_causal * cred_sets["studyId"].nunique(),
+ "false_positives": (len(cred_sets["is_in_X"]) - sum(cred_sets["is_in_X"]))
+ / len(cred_sets["is_in_X"]),
+ "accuracy": sum(cred_sets["is_in_X"]) / len(cred_sets["is_in_X"]),
+ "accuracy_lead": sum(cred_sets["is_in_lead"])
+ / len(cred_sets["is_in_lead"]),
+ "sensitivity": sum(cred_sets["is_in_X"])
+ / (n_causal * cred_sets["studyId"].nunique()),
+ }
+
+ @staticmethod
+ def SimulationLoop(
+ n_iter: int,
+ ld_index: DataFrame,
+ n_causal: int,
+ ld_matrix_for_sim: np.ndarray,
+ session: Session,
+ he2_reggen: float,
+ sample_size: int,
+ noise: bool = False,
+ scale_noise: float = 1,
+ run_carma: bool = False,
+ run_sumstat_imputation: bool = False,
+ prop_of_snps_to_noise: float = 0.1,
+    ) -> pd.DataFrame:
+ """Run a simulation cycle.
+
+ Args:
+ n_iter (int): Number of iterations.
+ ld_index (DataFrame): DataFrame containing the LD index.
+ n_causal (int): Number of causal SNPs.
+ ld_matrix_for_sim (np.ndarray): LD matrix.
+ session (Session): Session object.
+ he2_reggen (float): Heritability explained by the combined effect of the region and gene.
+ sample_size (int): Sample size.
+ noise (bool, optional): Add noise to the simulation. Defaults to False.
+ scale_noise (float, optional): Scale of the noise. Defaults to 1.
+ run_carma (bool, optional): Run CARMA. Defaults to False.
+ run_sumstat_imputation (bool, optional): Run summary statistics imputation. Defaults to False.
+            prop_of_snps_to_noise (float, optional): Proportion of SNPs to add noise to. Defaults to 0.1. Currently unused.
+
+ Returns:
+            pd.DataFrame: Pandas DataFrame containing the credible sets.
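+
+        Example:
+            A minimal sketch mirroring the benchmarking notebook (inputs are
+            illustrative)::
+
+                cred_sets = FineMappingSimulations.SimulationLoop(
+                    n_iter=100,
+                    n_causal=1,
+                    session=session,
+                    he2_reggen=0.003,
+                    sample_size=100_000,
+                    ld_matrix_for_sim=ld_matrix_for_sim,
+                    ld_index=ld_index_for_sim,
+                )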
+ """
+        # Compatibility shim: pandas 2.x removed DataFrame.iteritems, which
+        # older Spark versions still call when converting pandas DataFrames.
+        pd.DataFrame.iteritems = pd.DataFrame.items
+
+ ld_index_pd = ld_index.toPandas()
+        cred_sets = None
+ column_list = [
+ "credibleSetIndex",
+ "studyLocusId",
+ "studyId",
+ "region",
+ "exploded_locus",
+ "variantId",
+ "chromosome",
+ "position",
+ "credibleSetlog10BF",
+ "purityMeanR2",
+ "purityMinR2",
+ "zScore",
+ "pValueMantissa",
+ "pValueExponent",
+ "is_in_X",
+ "is_in_lead",
+ ]
+ for iteration in range(n_iter):
+ x_cycle = FineMappingSimulations.SimSumStatFromLD(
+ n_causal=n_causal,
+ he2_reggen=he2_reggen,
+ n=sample_size,
+ U=ld_matrix_for_sim,
+ noise=noise,
+ scale_noise=scale_noise,
+ )
+
+            # Only proceed when at least one variant reaches genome-wide significance
+            if sum(x_cycle["P"] <= 5e-8) > 0:
+ df = pd.DataFrame(
+ {"z": x_cycle["Z"], "variantId": ld_index_pd["variantId"]}
+ )
+ schema = StructType(
+ [
+ StructField("z", DoubleType(), True),
+ StructField("variantId", StringType(), True),
+ ]
+ )
+ df_spark = session.spark.createDataFrame(df, schema=schema)
+
+ j = ""
+ for ii in ld_index_pd["variantId"][x_cycle["indexes"]].tolist():
+ j = j + str(ii) + ","
+
+ CS_sim = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
+ GWAS_df=df_spark,
+ ld_index=ld_index,
+ gnomad_ld=ld_matrix_for_sim,
+ L=10,
+ session=session,
+ studyId="sim" + str(iteration),
+ region=j,
+ susie_est_tausq=False,
+ run_carma=run_carma,
+ run_sumstat_imputation=run_sumstat_imputation,
+ carma_time_limit=600,
+ imputed_r2_threshold=0.9,
+ ld_score_threshold=5,
+ sum_pips=0.99,
+ primary_signal_pval_threshold=1e-2,
+ secondary_signal_pval_threshold=1e-2,
+ purity_mean_r2_threshold=0,
+ purity_min_r2_threshold=0,
+ cs_lbf_thr=2,
+ )
+ cred_set = CS_sim["study_locus"].df
+
+ X = ld_index_pd["variantId"][x_cycle["indexes"]].tolist()
+
+ cred_set = cred_set.withColumn("exploded_locus", col("locus.variantId"))
+ # Create a condition for each element in X
+ conditions = [array_contains(col("exploded_locus"), x) for x in X]
+ # Combine the conditions using the | operator
+ combined_condition = conditions[0]
+ for condition in conditions[1:]:
+ combined_condition = combined_condition | condition
+ # Create a new column that is True if any condition is True and False otherwise
+ cred_set = cred_set.withColumn("is_in_X", combined_condition)
+
+ cred_set = cred_set.withColumn(
+ "is_in_lead", when(col("variantId").isin(X), 1).otherwise(0)
+ )
+
+ cred_set = cred_set.toPandas()
+ cred_set = cred_set[column_list]
+
+            if cred_sets is None:
+                cred_sets = cred_set
+            else:
+                cred_sets = pd.concat([cred_sets, cred_set], axis=0)
+
+ return cred_sets
+
+ @staticmethod
+ def SimSumStatFromLD(
+ n_causal: int,
+ he2_reggen: float,
+ U: np.ndarray,
+ n: int,
+ noise: bool = False,
+ scale_noise: float = 1,
+ ) -> dict[str, Any]:
+ """Simulates summary statistics (vector of Z-scores) using numbr of causla SNPs and LD matrix as input.
+
+ Args:
+            n_causal (int): Number of causal SNPs.
+            he2_reggen (float): Heritability explained by the combined effect of the region and gene.
+            U (np.ndarray): LD matrix.
+ n (int): Sample size.
+ noise (bool, optional): Add noise to the simulation. Defaults to False.
+ scale_noise (float, optional): Scale of the noise. Defaults to 1.
+
+ Returns:
+ dict[str, Any]: Dictionary containing the simulated summary statistics.
+ """
+        # Total number of SNPs in the analysis
+ M = U.shape[0]
+
+        # Per-SNP signal strength: sample size times the heritability assigned to each causal SNP
+        Tau = n * he2_reggen / n_causal
+
+ # Simulate the causal status of SNPs
+ indexes = np.random.choice(np.arange(M), size=n_causal, replace=False)
+ cc = np.repeat(0, M)
+ cc[indexes] = 1
+
+ # Simulate joint z-statistics
+ jz = np.zeros(M)
+ x = np.random.normal(loc=0, scale=1, size=n_causal)
+ jz[cc == 1] = x * np.sqrt(Tau)
+
+ # Simulate GWAS z-statistics
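+        # (Under a standard summary-statistics model the marginal z-scores are
+        # distributed as N(U @ jz, U), where U is the LD matrix.)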
+ muz = U @ jz
+ GWASz = np.random.multivariate_normal(mean=muz, cov=U, size=1)
+
+ GWASz = GWASz.flatten()
+
+        if noise:
+            # Pick one causal variant at random and perturb the z-scores of two
+            # variants in strong LD (|r| > 0.5) with it; if fewer than two such
+            # variants exist, perturb two randomly chosen variants instead.
+            indexes_causal = indexes[
+                np.random.choice(np.arange(len(indexes)), size=1, replace=False)
+            ]
+            x_tmp = np.abs(U[indexes_causal, :]).flatten()
+            x_tmp[indexes_causal] = 0
+            ind_tmp = np.where(x_tmp > 0.5)[0]
+            if len(ind_tmp) >= 2:
+                indexes_noise = ind_tmp[
+                    np.random.choice(np.arange(len(ind_tmp)), size=2, replace=False)
+                ]
+            else:
+                indexes_noise = np.random.choice(M, size=2, replace=False)
+            GWASz[indexes_noise] = GWASz[indexes_noise] + np.random.normal(
+                loc=0, scale=scale_noise, size=len(indexes_noise)
+            )
+
+ GWASp = chi2.sf(GWASz**2, df=1) # convert Z to Pval
+
+ return {
+ "Z": GWASz.flatten(),
+ "P": GWASp.flatten(),
+ "Tau": Tau,
+ "indexes": indexes,
+ }