diff --git a/benchmark/data/benchmark_api_calling_data.yaml b/benchmark/data/benchmark_api_calling_data.yaml index 3cca6a0a..7dd8e04f 100644 --- a/benchmark/data/benchmark_api_calling_data.yaml +++ b/benchmark/data/benchmark_api_calling_data.yaml @@ -93,9 +93,10 @@ api_calling: explicit_variable_names: "plot the tsne embeddding of the data colored by n_genes_by_counts." expected: parts_of_query: ["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"] + parts_of_query: ["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"] - case: anndata:read:h5ad input: prompt: explicit_variable_names: "read test.h5ad into an anndata object." expected: - parts_of_query: ["anndata.read_h5ad\\(", "filename=test.h5ad", "\\)"] + parts_of_query: ["anndata.read_h5ad\\(", "filename=test.h5ad", "\\)"] \ No newline at end of file diff --git a/biochatter/api_agent/__init__.py b/biochatter/api_agent/__init__.py index 0506e22c..9e6f9970 100644 --- a/biochatter/api_agent/__init__.py +++ b/biochatter/api_agent/__init__.py @@ -9,6 +9,7 @@ BlastQueryParameters, ) from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder +from .scanpy_tl import ScanpyTLQueryBuilder, ScanpyTLQueryFetcher, ScanpyTLQueryInterpreter from .scanpy_pl import ScanpyPlQueryBuilder from .formatters import format_as_rest_call, format_as_python_call @@ -32,7 +33,10 @@ "BioToolsInterpreter", "BioToolsQueryBuilder", "APIAgent", + "ScanpyTLQueryBuilder", + "ScanpyTLQueryFetcher", + "ScanpyTLQueryInterpreter", "ScanpyPlQueryBuilder", "format_as_rest_call", "format_as_python_call", -] \ No newline at end of file +] diff --git a/biochatter/api_agent/generate_pydantic_classes_from_module.py b/biochatter/api_agent/generate_pydantic_classes_from_module.py new file mode 100644 index 00000000..5a60f959 --- /dev/null +++ b/biochatter/api_agent/generate_pydantic_classes_from_module.py @@ -0,0 +1,77 @@ +import inspect +from typing import Any, Dict, Optional, Type +from types import ModuleType +from docstring_parser import parse +from 
def generate_pydantic_classes(module: ModuleType) -> list[Type[BaseModel]]:
    """Generate one Pydantic model per public function of ``module``.

    Parameter names and descriptions are extracted from each function's
    docstring with ``docstring_parser.parse``. Every documented parameter
    becomes an ``Optional[str]`` field defaulting to ``None``, because the
    docstring alone carries no reliable type information. A parameter whose
    name collides with an attribute of ``BaseModel`` is stored under
    ``<name>_param`` and exposed under its original name through a field
    alias.

    Functions whose name starts with an underscore are skipped. Functions
    without a docstring still get a model, but one without any fields.

    Parameters
    ----------
    module : ModuleType
        The Python module whose functions are turned into models.

    Returns
    -------
    list[Type[BaseModel]]
        One dynamically created model class per public function, named after
        the function, in the order yielded by ``inspect.getmembers``.

    Examples
    --------
    ::

        import scanpy as sc

        generated_classes = generate_pydantic_classes(sc.tl)
        for func in generated_classes:
            print(func.schema())
    """
    base_attributes = set(dir(BaseModel))
    classes_list = []

    for name, func in inspect.getmembers(module, inspect.isfunction):
        # Private helpers are not part of the module's public API.
        if name.startswith("_"):
            continue

        doc = inspect.getdoc(func)
        if not doc:
            # No docstring means no parameter information: emit an empty model.
            classes_list.append(create_model(name))
            continue

        # Keep only the first description seen for a given parameter name.
        param_info = {}
        for param in parse(doc).params:
            if param.arg_name not in param_info:
                param_info[param.arg_name] = (
                    param.description or "No description available."
                )

        fields = {}
        for param_name, param_desc in param_info.items():
            field_kwargs = {"default": None, "description": param_desc}
            field_name = param_name

            # Avoid shadowing BaseModel attributes: rename the field and keep
            # the original parameter name reachable as its alias.
            if param_name in base_attributes:
                field_name = param_name + "_param"
                field_kwargs["alias"] = param_name

            # Without type info, default to Optional[str].
            fields[field_name] = (Optional[str], Field(**field_kwargs))

        classes_list.append(create_model(name, **fields))

    return classes_list
class ScanpyTLQueryBuilder(BaseQueryBuilder):
    """Build parameterised ``scanpy.tl`` function calls from a question."""

    def create_runnable(
        self,
        query_parameters: BaseModel,
        conversation: "Conversation",
    ):
        """Do nothing and return ``None``.

        ``parameterise_query`` assembles its own chain and does not call this
        method.
        NOTE(review): presumably kept only to satisfy the ``BaseQueryBuilder``
        interface -- confirm against the base class.
        """
        pass

    def parameterise_query(
        self,
        question: str,
        conversation: "Conversation",
        generated_classes=None,  # Allow external injection of classes
        module=None,
    ):
        """Let the LLM choose and parameterise a tool call for ``question``.

        The candidate tools are Pydantic models, one per function of the
        target module. They are bound to the conversation's chat model with
        ``bind_tools`` and the model output is parsed with
        ``PydanticToolsParser``.

        Args:
        ----
            question (str): The question to be answered.

            conversation: The conversation object; its ``chat`` attribute is
                the model the tools are bound to.

            generated_classes: Optional list of Pydantic model classes to
                offer as tools. When ``None``, they are generated from
                ``module``.

            module: Optional module to generate the tool classes from. Only
                consulted when ``generated_classes`` is ``None``; defaults to
                ``scanpy.tl``.

        Returns:
        -------
            The result of invoking the ``llm_with_tools | PydanticToolsParser``
            chain on the system/human message pair.
            NOTE(review): presumably a list of instances of the generated
            models -- confirm against the LangChain documentation.

        """
        if generated_classes is None:
            if module is None:
                # Imported lazily: scanpy is an optional dependency and is
                # only needed when neither classes nor a module are injected.
                import scanpy as sc

                module = sc.tl
            generated_classes = generate_pydantic_classes(module)

        llm = conversation.chat
        llm_with_tools = llm.bind_tools(generated_classes)
        query = [
            ("system", "You're an expert data scientist"),
            ("human", question),
        ]
        chain = llm_with_tools | PydanticToolsParser(tools=generated_classes)
        return chain.invoke(query)
class TestScanpyTLQueryBuilder:
    """Unit tests for ``ScanpyTLQueryBuilder``."""

    @patch("biochatter.llm_connect.GptConversation")
    def test_parameterise_query(self, mock_conversation):
        """Check tool binding, the messages sent and the returned result.

        Tool classes are injected through ``generated_classes`` and the chat
        model, the tool-bound model and the resulting chain are all mocks.
        """

        # Stand-ins for the Pydantic tool classes handed to the builder.
        class MockTool1(BaseModel):
            param1: str

        class MockTool2(BaseModel):
            param2: int

        tool_classes = [MockTool1, MockTool2]
        user_question = "I want to mark mitochondrial genes of my adata object"
        expected_result = {"parameters": {"key_added": "mt_genes"}}

        # Chain produced by ``llm_with_tools | parser``.
        chain = MagicMock()
        chain.invoke.return_value = expected_result

        # Tool-bound model returned by ``bind_tools``.
        bound_llm = MagicMock()
        bound_llm.__or__.return_value = chain

        # Chat model exposed by the conversation.
        chat_model = MagicMock()
        chat_model.bind_tools.return_value = bound_llm
        conversation = mock_conversation.return_value
        conversation.chat = chat_model

        result = ScanpyTLQueryBuilder().parameterise_query(
            user_question,
            conversation,
            generated_classes=tool_classes,
        )

        chat_model.bind_tools.assert_called_once_with(tool_classes)
        chain.invoke.assert_called_once_with(
            [
                ("system", "You're an expert data scientist"),
                ("human", user_question),
            ]
        )
        assert result == expected_result
mock_runnable = MagicMock() - mock.return_value = mock_runnable - yield mock_runnable - def test_parameterise_query(self, mock_create_runnable): - # Arrange - query_builder = AnnDataIOQueryBuilder() - mock_conversation = MagicMock() - question = "read a .h5ad file into an anndata object." - mock_query_obj = MagicMock() - mock_create_runnable.invoke.return_value = mock_query_obj - # Act - result = query_builder.parameterise_query(question, mock_conversation) - # Assert - mock_create_runnable.invoke.assert_called_once_with(question) - assert result == mock_query_obj