Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the scanpy tool tl modules to API agent. #232

Merged
merged 17 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
db60467
add the tools `tl` modules to API agent __init__.py
bastienchassagnol Dec 10, 2024
6abb3a9
add scanpy_tl module with general description
bastienchassagnol Dec 10, 2024
9044563
Change pymilvus dependency in the pyproject.toml from the fixed versi…
bastienchassagnol Dec 11, 2024
a46df7a
api agent for scnapy tl using the generate_pydantic_class_from_module
mengerj Dec 11, 2024
d4f3184
generic method to generate pydantic classes for functions in a module.
mengerj Dec 11, 2024
7b4df80
working progress on QueryBuilder and its unit tests
mengerj Dec 11, 2024
93be844
Merge pull request #1 from mengerj/just_the_generic_function
bastienchassagnol Dec 11, 2024
0b5ea9d
Merge branch 'main' into dev/tl-bastien
bastienchassagnol Dec 11, 2024
40a7751
Merge pull request #2 from bastienchassagnol/dev/tl-bastien
bastienchassagnol Dec 11, 2024
f10839a
add in the benchmark a call to scanpy.pp to carry on a PCA with a giv…
bastienchassagnol Dec 11, 2024
90a9a75
Merge branch 'biohackathon3' into main
bastienchassagnol Dec 11, 2024
7e742df
Added mock test for ScanpyTLQueryBuilder (without module specification)
vd-dragan21 Dec 11, 2024
cb27ba2
Merge branch 'main' into mock_test
bastienchassagnol Dec 11, 2024
78b5d49
Merge pull request #3 from bastienchassagnol/mock_test
bastienchassagnol Dec 11, 2024
a1c20f9
remove irrelevant imports in scanpy.tl module
bastienchassagnol Dec 11, 2024
ecb3360
Merge branch 'main' of https://github.com/bastienchassagnol/biochatter
bastienchassagnol Dec 11, 2024
286bbd4
Merge branch 'biohackathon3' into main
bastienchassagnol Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmark/data/benchmark_api_calling_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ api_calling:
explicit_variable_names: "plot the tsne embeddding of the data colored by n_genes_by_counts."
expected:
parts_of_query: ["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"]
parts_of_query: ["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"]
- case: anndata:read:h5ad
input:
prompt:
explicit_variable_names: "read test.h5ad into an anndata object."
expected:
parts_of_query: ["anndata.read_h5ad\\(", "filename=test.h5ad", "\\)"]
parts_of_query: ["anndata.read_h5ad\\(", "filename=test.h5ad", "\\)"]
6 changes: 5 additions & 1 deletion biochatter/api_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BlastQueryParameters,
)
from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
from .scanpy_tl import ScanpyTLQueryBuilder, ScanpyTLQueryFetcher, ScanpyTLQueryInterpreter
from .scanpy_pl import ScanpyPlQueryBuilder
from .formatters import format_as_rest_call, format_as_python_call

Expand All @@ -32,7 +33,10 @@
"BioToolsInterpreter",
"BioToolsQueryBuilder",
"APIAgent",
"ScanpyTLQueryBuilder",
"ScanpyTLQueryFetcher",
"ScanpyTLQueryInterpreter",
"ScanpyPlQueryBuilder",
"format_as_rest_call",
"format_as_python_call",
]
]
77 changes: 77 additions & 0 deletions biochatter/api_agent/generate_pydantic_classes_from_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import inspect
from typing import Any, Dict, Optional, Type
from types import ModuleType
from docstring_parser import parse
from langchain_core.pydantic_v1 import BaseModel, Field, create_model

def generate_pydantic_classes(module: ModuleType) -> list[Type[BaseModel]]:
"""
Generate Pydantic classes for each callable (function/method) in a given module.

Extracts parameters from docstrings using docstring-parser. Each generated class
has fields corresponding to the parameters of the function. If a parameter name
conflicts with BaseModel attributes, it is aliased.

Parameters
----------
module : ModuleType
The Python module from which to extract functions and generate models.

Returns
-------
Dict[str, Type[BaseModel]]
A dictionary mapping function names to their corresponding Pydantic model classes.
"""
base_attributes = set(dir(BaseModel))
classes_list = []

# Iterate over all callables in the module
for name, func in inspect.getmembers(module, inspect.isfunction):
# skip if method starts with _
if name.startswith("_"):
continue
doc = inspect.getdoc(func)
if not doc:
# If no docstring, still create a model with no fields
TLParametersModel = create_model(f"{name}")
classes_list.append(TLParametersModel)
continue

parsed_doc = parse(doc)

# Collect parameter descriptions
param_info = {}
for p in parsed_doc.params:
if p.arg_name not in param_info:
param_info[p.arg_name] = p.description or "No description available."

# Prepare fields for create_model
fields = {}
alias_map = {}

for param_name, param_desc in param_info.items():
field_kwargs = {"default": None, "description": param_desc}
field_name = param_name

# Alias if conflicts with BaseModel attributes
if param_name in base_attributes:
aliased_name = param_name + "_param"
field_kwargs["alias"] = param_name
alias_map[aliased_name] = param_name
field_name = aliased_name

# Without type info, default to Optional[str]
fields[field_name] = (Optional[str], Field(**field_kwargs))

# Dynamically create the model for this function
TLParametersModel = create_model(name, **fields)
classes_list.append(TLParametersModel)

return classes_list


# Example usage:
#import scanpy as sc
#generated_classes = generate_pydantic_classes(sc.tl)
#for func in generated_classes:
# print(func.schema())
158 changes: 158 additions & 0 deletions biochatter/api_agent/scanpy_tl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Module for interacting with the `scanpy` API for data tools (`tl`)."""

from typing import TYPE_CHECKING

from langchain_core.output_parsers import PydanticToolsParser
from langchain_core.pydantic_v1 import BaseModel

from .abc import BaseQueryBuilder
from .generate_pydantic_classes_from_module import generate_pydantic_classes

if TYPE_CHECKING:
from biochatter.llm_connect import Conversation

SCANPY_QUERY_PROMPT = """
You are a world class algorithm for creating queries in structured formats. Your task is to use the scanpy python package
to provide the user with the appropriate function call to answer their question. You focus on the scanpy.tl module, which has
the following overview:
Any transformation of the data matrix that is not *preprocessing*. In contrast to a *preprocessing* function, a *tool* usually adds an easily interpretable annotation to the data matrix, which can then be visualized with a corresponding plotting function.

### Embeddings

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

pp.pca
tl.tsne
tl.umap
tl.draw_graph
tl.diffmap
```

Compute densities on embeddings.

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.embedding_density
```

### Clustering and trajectory inference

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.leiden
tl.louvain
tl.dendrogram
tl.dpt
tl.paga
```

### Data integration

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.ingest
```

### Marker genes

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.rank_genes_groups
tl.filter_rank_genes_groups
tl.marker_gene_overlap
```

### Gene scores, Cell cycle

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.score_genes
tl.score_genes_cell_cycle
```

### Simulations

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.sim

```
"""


class ScanpyTLQueryBuilder(BaseQueryBuilder):
"""A class for building an ScanpyTLQuery object."""

def create_runnable(
self,
query_parameters: BaseModel,
conversation: "Conversation",
):
pass

def parameterise_query(
self,
question: str,
conversation: "Conversation",
generated_classes=None, # Allow external injection of classes
module=None,
):
"""Generate an ScanpyTLQuery object.

Generate a ScanpyTLQuery object based on the given question, prompt,
and BioChatter conversation. Uses a Pydantic model to define the API
fields. Using langchains .bind_tools method to allow the LLM to parameterise
the function call, based on the functions available in thescanpy.tl module.

Args:
----
question (str): The question to be answered.

conversation: The conversation object used for parameterising the
BioToolsQuery.

Returns:
-------
BioToolsQueryParameters: the parameterised query object (Pydantic
model)

"""
import scanpy as sc

module = sc.tl
generated_classes = generate_pydantic_classes(module)
# Generate classes if not provided
if generated_classes is None:
generated_classes = generate_pydantic_classes(module)

llm = conversation.chat
llm_with_tools = llm.bind_tools(generated_classes)
query = [
("system", "You're an expert data scientist"),
("human", {question}),
("system", "You're an expert data scientist"),
("human", question),
]
chain = llm_with_tools | PydanticToolsParser(tools=generated_classes)
return chain.invoke(query)
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ langchain = "^0.2.5"
openai = "^1.1.0"
httpx = "0.27.2"
pymupdf = "^1.22.3"
pymilvus = "2.2.8"
pymilvus = ">=2.2.8"
nltk = "^3.8.1"
redis = "^4.5.5"
retry = "^0.9.2"
Expand All @@ -53,6 +53,7 @@ rouge_score = "0.1.2"
evaluate = "^0.4.1"
pillow = ">=10.2,<11.0"
pdf2image = "^1.16.0"
scanpy = { version = "^1.11.0", optional = true }
langchain-community = "^0.2.5"
langgraph = "^0.1.5"
langchain-openai = "^0.1.14"
Expand All @@ -62,6 +63,7 @@ importlib-metadata = "^8.0.0"
colorcet = "^3.1.0"
langchain-anthropic = "^0.1.22"
anthropic = "^0.33.0"
docstring-parser = "^0.16.0"

[tool.poetry.extras]
streamlit = ["streamlit"]
Expand Down
72 changes: 38 additions & 34 deletions test/test_api_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
OncoKBQueryBuilder,
OncoKBQueryParameters,
)
from biochatter.api_agent.scanpy_tl import ScanpyTLQueryBuilder
from biochatter.api_agent.scanpy_pl import (
SCANPY_PL_QUERY_PROMPT,
ScanpyPlQueryBuilder,
Expand Down Expand Up @@ -438,50 +439,53 @@ def test_summarise_results(mock_prompt, mock_conversation, mock_chain):
mock_chain.invoke.assert_called_once_with(
{"input": {expected_summary_prompt}},
)
class TestScanpyTLQueryBuilder:
@patch("biochatter.llm_connect.GptConversation")
def test_parameterise_query(self, mock_conversation):
# Arrange
question = "I want to mark mitochondrial genes of my adata object"

# Mock the list of Pydantic classes as a list of Mock objects
class MockTool1(BaseModel):
param1: str

class TestScanpyPlQueryBuilder:
@pytest.fixture()
def mock_create_runnable(self):
with patch(
"biochatter.api_agent.scanpy_pl.create_structured_output_runnable"
) as mock:
mock_runnable = MagicMock()
mock.return_value = mock_runnable
yield mock_runnable
class MockTool2(BaseModel):
param2: int

mock_generated_classes = [MockTool1, MockTool2]

# Mock the conversation object and LLM
mock_conversation_instance = mock_conversation.return_value
mock_llm = MagicMock()
mock_conversation_instance.chat = mock_llm

class TestScanpyPlFetcher:
pass
# Mock the LLM with tools
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools

# Mock the chain and its invoke method
mock_chain = MagicMock()
mock_llm_with_tools.__or__.return_value = mock_chain
mock_result = {"parameters": {"key_added": "mt_genes"}}
mock_chain.invoke.return_value = mock_result

class TestScanpyPlInterpreter:
pass
# Act
builder = ScanpyTLQueryBuilder()
result = builder.parameterise_query(
question,
mock_conversation_instance,
generated_classes=mock_generated_classes
)

# Assert
mock_llm.bind_tools.assert_called_once_with(mock_generated_classes)
mock_chain.invoke.assert_called_once_with([
("system", "You're an expert data scientist"),
("human", question),
])
assert result == mock_result

class TestAnndataIOQueryBuilder:
@pytest.fixture
def mock_create_runnable(self):
with patch(
"biochatter.api_agent.anndata.AnnDataIOQueryBuilder.create_runnable",
) as mock:
mock_runnable = MagicMock()
mock.return_value = mock_runnable
yield mock_runnable

def test_parameterise_query(self, mock_create_runnable):
# Arrange
query_builder = AnnDataIOQueryBuilder()
mock_conversation = MagicMock()
question = "read a .h5ad file into an anndata object."

mock_query_obj = MagicMock()
mock_create_runnable.invoke.return_value = mock_query_obj

# Act
result = query_builder.parameterise_query(question, mock_conversation)

# Assert
mock_create_runnable.invoke.assert_called_once_with(question)
assert result == mock_query_obj