diff --git a/docs/interface/BindingAPI.md b/docs/interface/BindingAPI.md
index e990a11..9ca7660 100644
--- a/docs/interface/BindingAPI.md
+++ b/docs/interface/BindingAPI.md
@@ -1 +1 @@
-::: yeastdnnexplorer.interface.BindingAPI.BindingAPI
+::: yeastdnnexplorer.interface.BindingAPI.BindingAPI
\ No newline at end of file
diff --git a/docs/interface/BindingConcatenatedAPI.md b/docs/interface/BindingConcatenatedAPI.md
new file mode 100644
index 0000000..0b4432b
--- /dev/null
+++ b/docs/interface/BindingConcatenatedAPI.md
@@ -0,0 +1 @@
+::: yeastdnnexplorer.interface.BindingConcatenatedAPI.BindingConcatenatedAPI
diff --git a/docs/interface/DtoAPI.md b/docs/interface/DtoAPI.md
new file mode 100644
index 0000000..f0ca1d5
--- /dev/null
+++ b/docs/interface/DtoAPI.md
@@ -0,0 +1 @@
+::: yeastdnnexplorer.interface.DtoAPI.DtoAPI
diff --git a/docs/tutorials/database_interface.ipynb b/docs/tutorials/database_interface.ipynb
index 79e41f7..972a94e 100644
--- a/docs/tutorials/database_interface.ipynb
+++ b/docs/tutorials/database_interface.ipynb
@@ -144,12 +144,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "from yeastdnnexplorer.interface import *\n",
-    "import matplotlib.pyplot as plt"
+    "import matplotlib.pyplot as plt\n",
+    "import json\n",
+    "import pandas as pd"
   ]
  },
 {
@@ -2227,6 +2229,78 @@
    "plt.show()"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## DTO API\n",
+   "\n",
+   "There is now an endpoint to submit jobs that run\n",
+   "[dual_threshold_optimization](https://github.com/BrentLab/Dual_Threshold_Optimization/tree/dev).\n",
+   "Submitting a job and retrieving its result are done with `submit` and `retrieve`,\n",
+   "similar to the `RankResponseAPI`. `read()` retrieves the records in the `metadata`\n",
+   "field; the DTO results themselves are stored as JSON in each record's `result` field.\n",
+   "\n",
+   "```python\n",
+   "dto_api = DtoAPI()\n",
+   "\n",
+   "# submit a DTO job like this. Note -- this can take ~30 minutes to run. It will fail\n",
+   "# if that promotersetsig_id and expression_id already exist in the DTO table\n",
+   "args = [\n",
+   "    {\n",
+   "        \"promotersetsig_id\": \"40\",\n",
+   "        \"expression_id\": \"77\",\n",
+   "        \"pss_rename_metric_columns\": False,\n",
+   "        \"pss_col1_ascending\": True,\n",
+   "        \"pss_col2_ascending\": False,\n",
+   "        \"pss_ranker_col1\": \"poisson_pval\",\n",
+   "        \"pss_ranker_col2\": \"callingcards_enrichment\",\n",
+   "        \"expression_col1_ascending\": False,\n",
+   "        \"expression_ranker_col1\": \"effect\",\n",
+   "        \"expression_ranker_col2\": None,\n",
+   "        \"expression_ranker_col1_abs\": True,\n",
+   "        \"n_permutations\": 1000,\n",
+   "        \"n_threads\": 28,\n",
+   "    }\n",
+   "]\n",
+   "\n",
+   "group_id = await dto_api.submit(post_dict=args)\n",
+   "\n",
+   "res = await dto_api.retrieve(group_id)\n",
+   "```"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# retrieve a DTO result with `read` to get records from the database\n",
+   "dto_api = DtoAPI()\n",
+   "dto_records = await dto_api.read()\n",
+   "\n",
+   "print(f'dto metadata results: {dto_records.get(\"metadata\")}')\n",
+   "\n",
+   "dto_results_list = []\n",
+   "\n",
+   "for i, row in dto_records.get(\"metadata\").iterrows():\n",
+   "    # the `result` field is serialized with single quotes; normalize it to\n",
+   "    # valid JSON before parsing\n",
+   "    dto_results = json.loads(row.result.replace(\"'\", '\"'))\n",
+   "\n",
+   "    dto_results['id'] = row.id\n",
+   "    dto_results['promotersetsig_id'] = row.promotersetsig\n",
+   "    dto_results['expression_id'] = row.expression\n",
+   "    dto_results['regulator'] = row.regulator_symbol\n",
+   "    dto_results['binding_source'] = row.binding_source\n",
+   "    dto_results['expression_effect'] = row.expression_source\n",
+   "\n",
+   "    dto_results_list.append(dto_results)\n",
+   "\n",
+   "dto_results_df = pd.DataFrame(dto_results_list)\n",
+   "\n",
+   "print(f'dto results: {dto_results_df}')"
+  ]
+ },
 {
  "cell_type": "markdown",
  "metadata": {},
diff --git a/mkdocs.yml b/mkdocs.yml
index 2c5a182..c5ad164 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -51,6 +51,7 @@ nav:
       - Visualizing and Testing Data Generation Methods: tutorials/visualizing_and_testing_data_generation_methods.ipynb
       - Generalized Logistic Models: tutorials/generalized_logistic_models.ipynb
       - LassoCV: tutorials/lassoCV.ipynb
+      - Interactor Modeling Workflow: tutorials/interactor_modeling_workflow.ipynb
   - API:
       - Data Loaders:
           - Synthetic Data Loader: data_loaders/synthetic_data_loader.md
@@ -75,16 +76,18 @@ nav:
           - Records Only Classes:
               - interface/BindingManualQCAPI.md
               - interface/DataSourceAPI.md
+              - interface/DtoAPI.md
               - interface/ExpressionManualQCAPI.md
               - interface/FileFormatAPI.md
               - interface/GenomicFeatureAPI.md
               - interface/RegulatorAPI.md
           - Records and Files Classes:
-              - interface/BindingAPI.md
-              - interface/CallingCardsBackgroundAPI.md
-              - interface/ExpressionAPI.md
-              - interface/PromoterSetAPI.md
-              - interface/PromoterSetSigAPI.md
+              - BindingAPI: interface/BindingAPI.md
+              - BindingConcatenatedAPI: interface/BindingConcatenatedAPI.md
+              - CallingCardsBackgroundAPI: interface/CallingCardsBackgroundAPI.md
+              - ExpressionAPI: interface/ExpressionAPI.md
+              - PromoterSetAPI: interface/PromoterSetAPI.md
+              - PromoterSetSigAPI: interface/PromoterSetSigAPI.md
           - Developer Classes:
               - interface/AbstractAPI.md
               - interface/AbstractRecordsAndFilesAPI.md
diff --git a/yeastdnnexplorer/interface/AbstractRecordsAndFilesAPI.py b/yeastdnnexplorer/interface/AbstractRecordsAndFilesAPI.py
index c48f6b4..fe79726 100644
--- a/yeastdnnexplorer/interface/AbstractRecordsAndFilesAPI.py
+++ b/yeastdnnexplorer/interface/AbstractRecordsAndFilesAPI.py
@@ -136,6 +136,9 @@ async def read(
         :param retrieve_files: Boolean. Whether to retrieve the files associated with
             the records. Defaults to False.
         :type retrieve_files: bool
+        :param kwargs: The following kwargs are used by the read() function; any
+            others are passed on to the callback function.
+            - timeout: The timeout, in seconds, for the GET request. Defaults to 120.
         :return: The result of the callback function.
         :rtype: Any
@@ -157,7 +160,8 @@
         export_url = f"{self.url.rstrip('/')}/{self.export_url_suffix}"
         self.logger.debug("read() export_url: %s", export_url)
 
-        async with aiohttp.ClientSession() as session:
+        timeout = aiohttp.ClientTimeout(total=kwargs.pop("timeout", 120))
+        async with aiohttp.ClientSession(timeout=timeout) as session:
             try:
                 async with session.get(
                     export_url, headers=self.header, params=self.params
diff --git a/yeastdnnexplorer/interface/AbstractRecordsOnlyAPI.py b/yeastdnnexplorer/interface/AbstractRecordsOnlyAPI.py
index 1ef1839..ce0386e 100644
--- a/yeastdnnexplorer/interface/AbstractRecordsOnlyAPI.py
+++ b/yeastdnnexplorer/interface/AbstractRecordsOnlyAPI.py
@@ -43,8 +43,9 @@ async def read(
             include `metadata`, `data`, and `cache` as parameters.
         :param export_url_suffix: The URL suffix for the export endpoint. This will
             return a response object with a csv file.
-        :param kwargs: Additional arguments to pass to the callback function.
-        :return: The result of the callback function.
+        :param kwargs: This can be used to pass `params` to the request in place
+            of `self.params`. If passed, `params` is popped off and the remaining
+            kwargs are passed on to the callback function.
         """
 
     if not callable(callback) or {"metadata", "data", "cache"} - set(
@@ -66,7 +67,7 @@
             async with session.get(
                 export_url,
                 headers=self.header,
-                params=self.params,
+                params=kwargs.pop("params", self.params),
             ) as response:
                 response.raise_for_status()
                 content = await response.content.read()
diff --git a/yeastdnnexplorer/interface/BindingConcatenatedAPI.py b/yeastdnnexplorer/interface/BindingConcatenatedAPI.py
new file mode 100644
index 0000000..9eefe67
--- /dev/null
+++ b/yeastdnnexplorer/interface/BindingConcatenatedAPI.py
@@ -0,0 +1,62 @@
+import os
+from typing import Any
+
+import pandas as pd
+
+from yeastdnnexplorer.interface.AbstractRecordsAndFilesAPI import (
+    AbstractRecordsAndFilesAPI,
+)
+
+
+class BindingConcatenatedAPI(AbstractRecordsAndFilesAPI):
+    """Class to interact with the BindingConcatenatedAPI endpoint."""
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Initialize the BindingConcatenatedAPI object.
+
+        :param kwargs: parameters to pass through AbstractRecordsAndFilesAPI to
+            AbstractAPI.
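+
+        Example (a minimal sketch; assumes `BINDINGCONCATENATED_URL` and a valid
+        token are configured in the environment)::
+
+            binding_concat_api = BindingConcatenatedAPI()
+            res = await binding_concat_api.read()
+            metadata_df = res.get("metadata")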
+
+        """
+        valid_param_keys = kwargs.pop(
+            "valid_param_keys",
+            [
+                "id",
+                "regulator",
+                "regulator_locus_tag",
+                "regulator_symbol",
+                "batch",
+                "replicate",
+                "source",
+                "strain",
+                "condition",
+                "lab",
+                "assay",
+                "workflow",
+                "data_usable",
+            ],
+        )
+
+        url = kwargs.pop("url", os.getenv("BINDINGCONCATENATED_URL", None))
+
+        super().__init__(url=url, valid_keys=valid_param_keys, **kwargs)
+
+    def create(self, data: dict[str, Any], **kwargs) -> Any:
+        raise NotImplementedError("The BindingConcatenatedAPI does not support create.")
+
+    def update(self, df: pd.DataFrame, **kwargs) -> Any:
+        raise NotImplementedError("The BindingConcatenatedAPI does not support update.")
+
+    def delete(self, id: str, **kwargs) -> Any:
+        raise NotImplementedError("The BindingConcatenatedAPI does not support delete.")
+
+    def submit(self, post_dict: dict[str, Any], **kwargs) -> Any:
+        raise NotImplementedError("The BindingConcatenatedAPI does not support submit.")
+
+    def retrieve(
+        self, group_task_id: str, timeout: int, polling_interval: int, **kwargs
+    ) -> Any:
+        raise NotImplementedError(
+            "The BindingConcatenatedAPI does not support retrieve."
+        )
diff --git a/yeastdnnexplorer/interface/DtoAPI.py b/yeastdnnexplorer/interface/DtoAPI.py
new file mode 100644
index 0000000..4cc1ee8
--- /dev/null
+++ b/yeastdnnexplorer/interface/DtoAPI.py
@@ -0,0 +1,197 @@
+import asyncio
+import json
+import os
+import time
+from typing import Any
+
+import aiohttp
+import pandas as pd
+import requests  # type: ignore
+
+from yeastdnnexplorer.interface.AbstractRecordsOnlyAPI import AbstractRecordsOnlyAPI
+
+
+class DtoAPI(AbstractRecordsOnlyAPI):
+    """
+    A class to interact with the DTO API.
+
+    Retrieves DTO (dual threshold optimization) records from the database.
+
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Initialize the DtoAPI object. This will serve as an interface to the DTO
+        endpoint of both the database and the application cache.
+
+        :param url: The URL of the DTO API
+        :param kwargs: Additional parameters to pass to AbstractAPI.
+
+        """
+        self.bulk_update_url_suffix = kwargs.pop(
+            "bulk_update_url_suffix", "bulk-update"
+        )
+
+        super().__init__(
+            url=kwargs.pop("url", os.getenv("DTO_URL", "")),
+            **kwargs,
+        )
+
+    async def submit(
+        self,
+        post_dict: dict[str, Any],
+        **kwargs,
+    ) -> Any:
+        """
+        Submit a DTO task to the DTO API.
+
+        :param post_dict: The dictionary to submit to the DTO API. A list of
+            dictionaries may also be passed to submit a batch of tasks (the type
+            hint is narrower than what is accepted).
+        :return: The group_task_id of the submitted task.
+
+        """
+        # make a post request with the post_dict to dto_url
+        dto_url = f"{self.url.rstrip('/')}/submit/"
+        self.logger.debug("dto_url: %s", dto_url)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                dto_url, headers=self.header, json=post_dict
+            ) as response:
+                try:
+                    response.raise_for_status()
+                except aiohttp.ClientResponseError as e:
+                    self.logger.error(
+                        "Failed to submit DTO task: Status %s, Reason %s",
+                        e.status,
+                        e.message,
+                    )
+                    raise
+                result = await response.json()
+                try:
+                    return result["group_task_id"]
+                except KeyError:
+                    self.logger.error(
+                        "Expected 'group_task_id' in response: %s", json.dumps(result)
+                    )
+                    raise
+
+    async def retrieve(
+        self,
+        group_task_id: str,
+        timeout: int = 300,
+        polling_interval: int = 2,
+        **kwargs,
+    ) -> dict[str, pd.DataFrame]:
+        """
+        Periodically check the task status and retrieve the result when the task
+        completes.
+
+        :param group_task_id: The task ID to retrieve results for.
+        :param timeout: The maximum time to wait for the task to complete (in seconds).
+        :param polling_interval: The time to wait between status checks (in seconds).
+        :return: Records from the DTO API of the successfully completed task.
+
+        """
+        # Start time for timeout check
+        start_time = time.time()
+
+        # Task status URL
+        status_url = f"{self.url.rstrip('/')}/status/"
+
+        while True:
+            async with aiohttp.ClientSession() as session:
+                # Send a GET request to check the task status
+                async with session.get(
+                    status_url,
+                    headers=self.header,
+                    params={"group_task_id": group_task_id},
+                ) as response:
+                    response.raise_for_status()  # Raise an error for bad status codes
+                    status_response = await response.json()
+
+                    # Check if the task is complete
+                    if status_response.get("status") == "SUCCESS":
+                        if error_tasks := status_response.get("error_tasks"):
+                            self.logger.error(
+                                f"Tasks in group {group_task_id} failed: {error_tasks}"
+                            )
+                        if success_tasks := status_response.get("success_pks"):
+                            params = {"id": ",".join(str(pk) for pk in success_tasks)}
+                            return await self.read(params=params)
+                    elif status_response.get("status") == "FAILURE":
+                        raise Exception(
+                            f"Task {group_task_id} failed: {status_response}"
+                        )
+
+            # Check if we have reached the timeout
+            elapsed_time = time.time() - start_time
+            if elapsed_time > timeout:
+                raise TimeoutError(
+                    f"Task {group_task_id} did not "
+                    f"complete within {timeout} seconds."
+                )
+
+            # Wait for the specified polling interval before checking again
+            await asyncio.sleep(polling_interval)
+
+    def create(self, data: dict[str, Any], **kwargs) -> requests.Response:
+        raise NotImplementedError("The DtoAPI does not support create.")
+
+    def update(self, df: pd.DataFrame, **kwargs: Any) -> requests.Response:
+        """
+        Update the records in the database.
+
+        :param df: The DataFrame containing the records to update.
+        :type df: pd.DataFrame
+        :param kwargs: Additional fields to include in the payload.
+        :type kwargs: Any
+        :return: The response from the POST request.
+        :rtype: requests.Response
+        :raises requests.RequestException: If the request fails.
+
+        """
+        bulk_update_url = (
+            f"{self.url.rstrip('/')}/{self.bulk_update_url_suffix.rstrip('/')}/"
+        )
+
+        self.logger.debug("bulk_update_url: %s", bulk_update_url)
+
+        # Include additional fields in the payload if provided
+        payload = {"data": df.to_dict(orient="records")}
+        payload.update(kwargs)
+
+        try:
+            response = requests.post(
+                bulk_update_url,
+                headers=self.header,
+                json=payload,
+            )
+            response.raise_for_status()
+            return response
+        except requests.RequestException as e:
+            self.logger.error(f"Error in POST request: {e}")
+            raise
+
+    def delete(self, id: str, **kwargs) -> Any:
+        """
+        Delete a DTO record from the database.
+
+        :param id: The ID of the DTO record to delete.
+        :return: A dictionary with a status message indicating success or failure.
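+
+        Example (the record ID is hypothetical; assumes a valid token)::
+
+            dto_api = DtoAPI()
+            dto_api.delete(id="42")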
+
+        """
+        # Include the Authorization header with the token. Pop `headers` so it
+        # is not passed to requests.delete() twice via **kwargs.
+        headers = kwargs.pop("headers", {})
+        headers["Authorization"] = f"Token {self.token}"
+
+        # Make the DELETE request with the updated headers
+        response = requests.delete(f"{self.url}/{id}/", headers=headers, **kwargs)
+
+        if response.status_code == 204:
+            return {"status": "success", "message": "DTO deleted successfully."}
+
+        # Raise an error if the response indicates failure
+        response.raise_for_status()
diff --git a/yeastdnnexplorer/interface/ExpressionAPI.py b/yeastdnnexplorer/interface/ExpressionAPI.py
index a0c8e9b..23fb32e 100644
--- a/yeastdnnexplorer/interface/ExpressionAPI.py
+++ b/yeastdnnexplorer/interface/ExpressionAPI.py
@@ -39,6 +39,7 @@ def __init__(self, **kwargs) -> None:
             "workflow",
             "effect_colname",
             "pvalue_colname",
+            "preferred_replicate",
         ],
     )
diff --git a/yeastdnnexplorer/interface/ExpressionManualQCAPI.py b/yeastdnnexplorer/interface/ExpressionManualQCAPI.py
index 1252a99..98a54e5 100644
--- a/yeastdnnexplorer/interface/ExpressionManualQCAPI.py
+++ b/yeastdnnexplorer/interface/ExpressionManualQCAPI.py
@@ -2,6 +2,7 @@
 from typing import Any
 
 import pandas as pd
+import requests  # type: ignore
 
 from yeastdnnexplorer.interface.AbstractRecordsOnlyAPI import AbstractRecordsOnlyAPI
 
@@ -44,13 +45,49 @@ def __init__(self, **kwargs):
             "`EXPRESSIONMANUALQC_URL` must be set",
         )
 
+        self.bulk_update_url_suffix = kwargs.pop(
+            "bulk_update_url_suffix", "bulk-update"
+        )
+
         super().__init__(url=url, valid_keys=valid_param_keys, **kwargs)
 
     def create(self, data: dict[str, Any], **kwargs) -> Any:
         raise NotImplementedError("The ExpressionManualQCAPI does not support create.")
 
-    def update(self, df: pd.DataFrame, **kwargs) -> Any:
-        raise NotImplementedError("The ExpressionManualQCAPI does not support update.")
+    def update(self, df: pd.DataFrame, **kwargs: Any) -> requests.Response:
+        """
+        Update the records in the database.
+
+        :param df: The DataFrame containing the records to update.
+        :type df: pd.DataFrame
+        :param kwargs: Additional fields to include in the payload.
+        :type kwargs: Any
+        :return: The response from the POST request.
+        :rtype: requests.Response
+        :raises requests.RequestException: If the request fails.
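+
+        Example (a minimal sketch; the edited column is hypothetical)::
+
+            qc_api = ExpressionManualQCAPI()
+            res = await qc_api.read()
+            df = res.get("metadata")
+            df["notes"] = "reviewed"
+            qc_api.update(df)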
+
+        """
+        bulk_update_url = (
+            f"{self.url.rstrip('/')}/{self.bulk_update_url_suffix.rstrip('/')}/"
+        )
+
+        self.logger.debug("bulk_update_url: %s", bulk_update_url)
+
+        # Include additional fields in the payload if provided
+        payload = {"data": df.to_dict(orient="records")}
+        payload.update(kwargs)
+
+        try:
+            response = requests.post(
+                bulk_update_url,
+                headers=self.header,
+                json=payload,
+            )
+            response.raise_for_status()
+            return response
+        except requests.RequestException as e:
+            self.logger.error(f"Error in POST request: {e}")
+            raise
 
     def delete(self, id: str, **kwargs) -> Any:
         raise NotImplementedError("The ExpressionManualQCAPI does not support delete.")
diff --git a/yeastdnnexplorer/interface/PromoterSetSigAPI.py b/yeastdnnexplorer/interface/PromoterSetSigAPI.py
index d61995d..6e59ee0 100644
--- a/yeastdnnexplorer/interface/PromoterSetSigAPI.py
+++ b/yeastdnnexplorer/interface/PromoterSetSigAPI.py
@@ -41,6 +41,7 @@ def __init__(self, **kwargs) -> None:
             "aggregated",
             "condition",
             "deduplicate",
+            "preferred_replicate",
         ],
     )
diff --git a/yeastdnnexplorer/interface/RankResponseAPI.py b/yeastdnnexplorer/interface/RankResponseAPI.py
index d703c89..6b1a389 100644
--- a/yeastdnnexplorer/interface/RankResponseAPI.py
+++ b/yeastdnnexplorer/interface/RankResponseAPI.py
@@ -8,7 +8,7 @@
 import aiohttp
 import pandas as pd
-from requests import Response, post  # type: ignore
+from requests import Response, delete, post  # type: ignore
 from requests_toolbelt import MultipartEncoder
 
 from yeastdnnexplorer.interface.AbstractRecordsAndFilesAPI import (
@@ -239,5 +239,48 @@ def create(self, data: dict[str, Any], **kwargs) -> Response:
     def update(self, df: pd.DataFrame, **kwargs) -> Any:
         raise NotImplementedError("The RankResponseAPI does not support update.")
 
-    def delete(self, id: str, **kwargs) -> Any:
-        raise NotImplementedError("The RankResponseAPI does not support delete.")
+    def delete(self, id: str = "", **kwargs) -> Any:
+        """
+        Delete one or more records from the database.
+
+        :param id: The ID of the record to delete. You can instead pass `ids`, a
+            list of IDs, to delete multiple records; this is why `id` is optional.
+            If neither `id` nor `ids` is provided, a ValueError is raised.
+
+        :return: A dictionary with a status message indicating success or failure.
+
+        :raises ValueError: If neither `id` nor `ids` is provided.
+
+        """
+        # Include the Authorization header with the token. Pop `headers` so it
+        # is not passed to delete() twice via **kwargs.
+        headers = kwargs.pop("headers", {})
+        headers["Authorization"] = f"Token {self.token}"
+
+        ids = kwargs.pop("ids", str(id))
+
+        # Determine if it's a single ID or multiple
+        if isinstance(ids, str) and ids != "":
+            # Single ID deletion for backward compatibility
+            response = delete(f"{self.url}/{ids}/", headers=headers, **kwargs)
+        elif isinstance(ids, list) and ids:
+            # Bulk delete with a list of IDs
+            response = delete(
+                f"{self.url}/delete/",
+                headers=headers,
+                json={"ids": ids},  # Send the list of IDs in the request body
+                **kwargs,
+            )
+        else:
+            raise ValueError(
+                "No ID(s) provided for deletion. Either pass a single ID with "
+                "`id` or a list of IDs with `ids=[1, 2, ...]`."
+            )
+
+        if response.status_code in [200, 204]:
+            return {
+                "status": "success",
+                "message": "RankResponse(s) deleted successfully.",
+            }
+
+        # Raise an error if the response indicates failure
+        response.raise_for_status()
diff --git a/yeastdnnexplorer/interface/__init__.py b/yeastdnnexplorer/interface/__init__.py
index eec2ebb..51be000 100644
--- a/yeastdnnexplorer/interface/__init__.py
+++ b/yeastdnnexplorer/interface/__init__.py
@@ -1,7 +1,9 @@
 from .BindingAPI import BindingAPI
+from .BindingConcatenatedAPI import BindingConcatenatedAPI
 from .BindingManualQCAPI import BindingManualQCAPI
 from .CallingCardsBackgroundAPI import CallingCardsBackgroundAPI
 from .DataSourceAPI import DataSourceAPI
+from .DtoAPI import DtoAPI
 from .ExpressionAPI import ExpressionAPI
 from .ExpressionManualQCAPI import ExpressionManualQCAPI
 from .FileFormatAPI import FileFormatAPI
@@ -12,15 +14,18 @@
 from .rank_transforms import (
     negative_log_transform_by_pvalue_and_enrichment,
     shifted_negative_log_ranks,
+    stable_rank,
 )
 from .RankResponseAPI import RankResponseAPI
 from .RegulatorAPI import RegulatorAPI
 
 __all__ = [
     "BindingAPI",
+    "BindingConcatenatedAPI",
     "BindingManualQCAPI",
     "CallingCardsBackgroundAPI",
     "DataSourceAPI",
+    "DtoAPI",
     "ExpressionAPI",
     "ExpressionManualQCAPI",
     "FileFormatAPI",
@@ -31,5 +36,6 @@
     "PromoterSetSigAPI",
     "RankResponseAPI",
     "RegulatorAPI",
+    "stable_rank",
     "shifted_negative_log_ranks",
 ]
diff --git a/yeastdnnexplorer/interface/rank_transforms.py b/yeastdnnexplorer/interface/rank_transforms.py
index 5ae2328..4634621 100644
--- a/yeastdnnexplorer/interface/rank_transforms.py
+++ b/yeastdnnexplorer/interface/rank_transforms.py
@@ -20,9 +20,7 @@ def shifted_negative_log_ranks(ranks: np.ndarray) -> np.ndarray:
     return -1 * np.log10(ranks) + log_max_rank
 
 
-def negative_log_transform_by_pvalue_and_enrichment(
-    pvalue_vector: np.ndarray, enrichment_vector: np.ndarray
-) -> np.ndarray:
+def stable_rank(pvalue_vector: np.ndarray, enrichment_vector: np.ndarray) -> np.ndarray:
     """
     Ranks data by primary_column, breaking ties based on secondary_column. The
     expected primary and secondary columns are 'pvalue' and 'enrichment', respectively. Then the
@@ -88,4 +86,25 @@
     # Step 4: Final rank based on the adjusted primary ranks
     final_ranks = rankdata(adjusted_primary_rank, method="average")
 
     return final_ranks
+
+
+def negative_log_transform_by_pvalue_and_enrichment(
+    pvalue_vector: np.ndarray, enrichment_vector: np.ndarray
+) -> np.ndarray:
+    """
+    This calls the stable_rank() function and then transforms the ranks to negative
+    log10 values, shifted to the right such that the lowest value (largest rank,
+    least important) is 0.
+
+    :param pvalue_vector: A vector of pvalues
+    :param enrichment_vector: A vector of enrichment values corresponding to the
+        pvalues
+    :return np.ndarray: A vector of negative log10 transformed ranks shifted such
+        that the lowest value is 0 and the highest value is log10(min_rank)
+    :raises ValueError: If the primary or secondary column is not numeric.
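+
+    Example (hypothetical values; ties on pvalue are broken by enrichment)::
+
+        pvals = np.array([0.01, 0.01, 0.5])
+        enrich = np.array([4.0, 2.0, 1.0])
+        transformed = negative_log_transform_by_pvalue_and_enrichment(pvals, enrich)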
+
+    """
+
+    final_ranks = stable_rank(pvalue_vector, enrichment_vector)
+    return shifted_negative_log_ranks(final_ranks)
diff --git a/yeastdnnexplorer/utils/stable_rank.py b/yeastdnnexplorer/utils/stable_rank.py
new file mode 100644
index 0000000..6512aaa
--- /dev/null
+++ b/yeastdnnexplorer/utils/stable_rank.py
@@ -0,0 +1,48 @@
+import numpy as np
+from scipy.stats import rankdata
+
+
+def stable_rank(
+    col1: np.ndarray,
+    col2: np.ndarray | None = None,
+    col1_ascending: bool = True,
+    col2_ascending: bool = True,
+    method="average",
+) -> np.ndarray:
+    """
+    Rank `col1`, breaking ties with `col2` if it is provided.
+
+    :param col1: The primary vector to rank.
+    :param col2: An optional secondary vector used only to order ties in `col1`.
+    :param col1_ascending: Whether to rank `col1` in ascending order.
+    :param col2_ascending: Whether to rank `col2` in ascending order.
+    :param method: The tie-handling method passed to scipy.stats.rankdata.
+    :return: A vector of ranks.
+    :raises ValueError: If `col1` or `col2` is not numeric.
+    """
+    # Validate inputs
+    if not np.issubdtype(col1.dtype, np.number):
+        raise ValueError("`col1` must be numeric")
+    if col2 is not None and not np.issubdtype(col2.dtype, np.number):
+        raise ValueError("`col2` must be numeric if provided")
+
+    # Rank `col1`
+    col1_sorted = col1 if col1_ascending else -col1
+    primary_rank = rankdata(col1_sorted, method="min")
+
+    if col2 is None:
+        # If `col2` is not provided, return ranks based only on `col1`
+        return primary_rank
+
+    # Handle ties using `col2`
+    adjusted_primary_rank = primary_rank.astype(float)
+    unique_ranks = np.unique(primary_rank)
+
+    for unique_rank in unique_ranks:
+        tie_indices = np.where(primary_rank == unique_rank)[0]
+
+        if len(tie_indices) > 1:  # Adjust only in case of ties
+            col2_sorted = col2[tie_indices] if col2_ascending else -col2[tie_indices]
+            secondary_rank_within_ties = rankdata(col2_sorted, method=method)
+
+            # Dynamically scale secondary ranks to prevent overlaps
+            max_secondary_rank = np.max(secondary_rank_within_ties)
+            scale_factor = 0.9 / max_secondary_rank
+
+            adjusted_primary_rank[tie_indices] += (
+                secondary_rank_within_ties * scale_factor
+            )
+
+    # Final rank
+    final_ranks = rankdata(adjusted_primary_rank, method=method)
+    return final_ranks
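A minimal usage sketch of the new `stable_rank` utility (hypothetical values, not part of the diff): `col1` is ranked ascending and ties are broken by `col2` in descending order.

```python
import numpy as np

from yeastdnnexplorer.utils.stable_rank import stable_rank

# two observations tie on col1 (0.01); col2 breaks the tie
col1 = np.array([0.01, 0.01, 0.5, 0.2])
col2 = np.array([2.0, 5.0, 1.0, 3.0])

# with col2_ascending=False, the larger col2 value wins the tie
ranks = stable_rank(col1, col2, col1_ascending=True, col2_ascending=False)
print(ranks)  # [2. 1. 4. 3.]
```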