Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor manual pydantics for scanpy pl agents #255

Open
wants to merge 20 commits into
base: biohackathon3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ serve.sh
.api_results/*
*.coverage
scaling_test.py
myvenv/
274 changes: 137 additions & 137 deletions benchmark/data/benchmark_api_calling_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ api_calling:
# explicit_variable_names: "Calculate UMAP embedding with minimum distance 0.3 and spread 1.0."
# expected:
# parts_of_query: ["sc.tl.umap\\(", "min_dist=0.3", "spread=1.0", "\\)"]
- case: scanpy:tl:leiden
input:
prompt:
explicit_variable_names: "Perform Leiden clustering on the data with resolution 0.5."
expected:
parts_of_query: ["sc.tl.leiden\\(", "resolution=0.5", "\\)"]
- case: scanpy:tl:umap
input:
prompt:
explicit_variable_names: "Calculate UMAP embedding with minimum distance 0.3 and spread 1.0."
expected:
parts_of_query: ["sc.tl.umap\\(", "min_dist=0.3", "spread=1.0", "\\)"]
# - case: scanpy:tl:leiden
# input:
# prompt:
# explicit_variable_names: "Perform Leiden clustering on the data with resolution 0.5."
# expected:
# parts_of_query: ["sc.tl.leiden\\(", "resolution=0.5", "\\)"]
# - case: scanpy:tl:umap
# input:
# prompt:
# explicit_variable_names: "Calculate UMAP embedding with minimum distance 0.3 and spread 1.0."
# expected:
# parts_of_query: ["sc.tl.umap\\(", "min_dist=0.3", "spread=1.0", "\\)"]
- case: scanpy:pl:scatter
input:
prompt:
Expand All @@ -113,128 +113,128 @@ api_calling:
"total_counts",
"\\)",
]
- case: scanpy:pl:pca
input:
prompt:
specific: "plot the PCA embedding colored by n_genes_by_counts and total_counts"
abbreviations: "plt the PC emb with n_genes_by_counts and total_counts as colors."
general_question: "How can I plot the PCA embedding with n_genes_by_counts and total_counts as colors?"
help_request: "Can you help me with plotting the PCA embedding with n_genes_by_counts and total_counts as colors?"
expected:
parts_of_query:
[
"sc.pl.pca\\(",
"adata=adata",
"n_genes_by_counts",
"total_counts",
"\\)",
]
- case: scanpy:pl:tsne
input:
prompt:
specific: "plot a tsne colored by n_genes_by_counts."
abbreviations: "tsne plt with n_genes_by_counts as colors."
general_question: "How can I plot a tsne with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a tsne with n_genes_by_counts as colors?"
expected:
parts_of_query:
["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:umap
input:
prompt:
specific: "plot a umap colored by number of n_genes_by_counts."
abbreviations: "umap plt with n_genes_by_counts as colors."
general_question: "How can I plot a umap with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a umap with n_genes_by_counts as colors?"
expected:
parts_of_query:
["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:draw_graph
input:
prompt:
specific: "plot a force-directed graph colored by n_genes_by_counts."
abbreviations: "force-directed plt with n_genes_by_counts as colors."
general_question: "How can I plot a force-directed graph with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting a force-directed graph with n_genes_by_counts as colors?"
expected:
parts_of_query:
["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: scanpy:pl:spatial
input:
prompt:
specific: "plot a the spatial data colored by n_genes_by_counts."
abbreviations: "spatial data plt with n_genes_by_counts as colors."
general_question: "How can I plot the spatial data with n_genes_by_counts as colors?"
help_request: "Can you help me with plotting the spatial data with n_genes_by_counts as colors?"
expected:
parts_of_query:
["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
- case: anndata:read:h5ad
input:
prompt:
explicit_variable_names: "Use AnnData to load the file test.h5ad into an AnnData object."
specific: "Load test.h5ad using AnnData."
abbreviation: "Read test.h5ad with AnnData."
general: "Open an H5AD file and load it as an AnnData object."
help_request: "How do I read test.h5ad into an AnnData object?"
expected:
parts_of_query: ["anndata.io.read_h5ad\\(", "filename=test.h5ad", "\\)"]
- case: anndata:read:csv
input:
prompt:
explicit_variable_names: "Use AnnData to load the file `test.csv` into an AnnData object."
specific: "Load test.csv using AnnData."
abbreviation: "Read test.csv with AnnData."
general: "Open a CSV file and load it as an AnnData object."
help_request: "How do I read test.csv into an AnnData object?"
expected:
parts_of_query: ["anndata.io.read_csv\\(", "filename=test.csv", "\\)"]
- case: anndata:concat:var
input:
prompt:
explicit_variable_names: "Concatenate adata1 and adata2 into a single AnnData object along the column axis with an inner join."
specific: "Join adata1 and adata2 by columns using AnnData with an inner join."
abbreviation: "Merge columns of adata1 and adata2 with AnnData."
general: "Combine two AnnData objects along the variable axis with an inner join."
help_request: "How do I concatenate adata1 and adata2 along columns?"
expected:
parts_of_query:
[
"anndata.concat\\(",
"\\[adata1, adata2\\]",
", axis='var', join='inner'",
"\\)",
]
- case: anndata:concat:obs
input:
prompt:
explicit_variable_names: "Concatenate adata1 and adata2 into a single AnnData object along the row axis with an outer join."
specific: "Join adata1 and adata2 by rows using AnnData with an outer join."
abbreviation: "Merge rows of adata1 and adata2 with AnnData."
general: "Combine two AnnData objects along the observation axis with an outer join."
help_request: "How do I concatenate adata1 and adata2 along rows?"
expected:
parts_of_query:
[
"anndata.concat\\(",
"\\[adata1, adata2\\]",
", axis='obs', join='outer'",
"\\)",
]
- case: anndata:map
input:
prompt:
explicit_variable_names: "Replace the values in the cell_type column of the obs attribute in adata. Replace type1 with new_type1, type2 with new_type2, and type3 with new_type3."
help_request: "How do I remap cell_type values to replace type1 with new_type1, type2 with new_type2, and type3 with new_type3. ?"
expected:
parts_of_query:
[
"adata.obs",
"\\[\"cell_type\"\\]",
"\\.map\\(",
"\\{\\s*\"type1\": \"new_type1\"",
"\\s*\"type2\": \"new_type2\"",
"\\s*\"type3\": \"new_type3\"",
"\\s*\\}\\)",
"\\.copy\\(\\)",
]
# - case: scanpy:pl:pca
# input:
# prompt:
# specific: "plot the PCA embedding colored by n_genes_by_counts and total_counts"
# abbreviations: "plt the PC emb with n_genes_by_counts and total_counts as colors."
# general_question: "How can I plot the PCA embedding with n_genes_by_counts and total_counts as colors?"
# help_request: "Can you help me with plotting the PCA embedding with n_genes_by_counts and total_counts as colors?"
# expected:
# parts_of_query:
# [
# "sc.pl.pca\\(",
# "adata=adata",
# "n_genes_by_counts",
# "total_counts",
# "\\)",
# ]
# - case: scanpy:pl:tsne
# input:
# prompt:
# specific: "plot a tsne colored by n_genes_by_counts."
# abbreviations: "tsne plt with n_genes_by_counts as colors."
# general_question: "How can I plot a tsne with n_genes_by_counts as colors?"
# help_request: "Can you help me with plotting a tsne with n_genes_by_counts as colors?"
# expected:
# parts_of_query:
# ["sc.pl.tsne\\(", "adata=adata", "n_genes_by_counts", "\\)"]
# - case: scanpy:pl:umap
# input:
# prompt:
# specific: "plot a umap colored by number of n_genes_by_counts."
# abbreviations: "umap plt with n_genes_by_counts as colors."
# general_question: "How can I plot a umap with n_genes_by_counts as colors?"
# help_request: "Can you help me with plotting a umap with n_genes_by_counts as colors?"
# expected:
# parts_of_query:
# ["sc.pl.umap\\(", "adata=adata", "n_genes_by_counts", "\\)"]
# - case: scanpy:pl:draw_graph
# input:
# prompt:
# specific: "plot a force-directed graph colored by n_genes_by_counts."
# abbreviations: "force-directed plt with n_genes_by_counts as colors."
# general_question: "How can I plot a force-directed graph with n_genes_by_counts as colors?"
# help_request: "Can you help me with plotting a force-directed graph with n_genes_by_counts as colors?"
# expected:
# parts_of_query:
# ["sc.pl.draw_graph\\(", "adata=adata", "n_genes_by_counts", "\\)"]
# - case: scanpy:pl:spatial
# input:
# prompt:
# specific: "plot a the spatial data colored by n_genes_by_counts."
# abbreviations: "spatial data plt with n_genes_by_counts as colors."
# general_question: "How can I plot the spatial data with n_genes_by_counts as colors?"
# help_request: "Can you help me with plotting the spatial data with n_genes_by_counts as colors?"
# expected:
# parts_of_query:
# ["sc.pl.spatial\\(", "adata=adata", "n_genes_by_counts", "\\)"]
# - case: anndata:read:h5ad
# input:
# prompt:
# explicit_variable_names: "Use AnnData to load the file test.h5ad into an AnnData object."
# specific: "Load test.h5ad using AnnData."
# abbreviation: "Read test.h5ad with AnnData."
# general: "Open an H5AD file and load it as an AnnData object."
# help_request: "How do I read test.h5ad into an AnnData object?"
# expected:
# parts_of_query: ["anndata.io.read_h5ad\\(", "filename=test.h5ad", "\\)"]
# - case: anndata:read:csv
# input:
# prompt:
# explicit_variable_names: "Use AnnData to load the file `test.csv` into an AnnData object."
# specific: "Load test.csv using AnnData."
# abbreviation: "Read test.csv with AnnData."
# general: "Open a CSV file and load it as an AnnData object."
# help_request: "How do I read test.csv into an AnnData object?"
# expected:
# parts_of_query: ["anndata.io.read_csv\\(", "filename=test.csv", "\\)"]
# - case: anndata:concat:var
# input:
# prompt:
# explicit_variable_names: "Concatenate adata1 and adata2 into a single AnnData object along the column axis with an inner join."
# specific: "Join adata1 and adata2 by columns using AnnData with an inner join."
# abbreviation: "Merge columns of adata1 and adata2 with AnnData."
# general: "Combine two AnnData objects along the variable axis with an inner join."
# help_request: "How do I concatenate adata1 and adata2 along columns?"
# expected:
# parts_of_query:
# [
# "anndata.concat\\(",
# "\\[adata1, adata2\\]",
# ", axis='var', join='inner'",
# "\\)",
# ]
# - case: anndata:concat:obs
# input:
# prompt:
# explicit_variable_names: "Concatenate adata1 and adata2 into a single AnnData object along the row axis with an outer join."
# specific: "Join adata1 and adata2 by rows using AnnData with an outer join."
# abbreviation: "Merge rows of adata1 and adata2 with AnnData."
# general: "Combine two AnnData objects along the observation axis with an outer join."
# help_request: "How do I concatenate adata1 and adata2 along rows?"
# expected:
# parts_of_query:
# [
# "anndata.concat\\(",
# "\\[adata1, adata2\\]",
# ", axis='obs', join='outer'",
# "\\)",
# ]
# - case: anndata:map
# input:
# prompt:
# explicit_variable_names: "Replace the values in the cell_type column of the obs attribute in adata. Replace type1 with new_type1, type2 with new_type2, and type3 with new_type3."
# help_request: "How do I remap cell_type values to replace type1 with new_type1, type2 with new_type2, and type3 with new_type3. ?"
# expected:
# parts_of_query:
# [
# "adata.obs",
# "\\[\"cell_type\"\\]",
# "\\.map\\(",
# "\\{\\s*\"type1\": \"new_type1\"",
# "\\s*\"type2\": \"new_type2\"",
# "\\s*\"type3\": \"new_type3\"",
# "\\s*\\}\\)",
# "\\.copy\\(\\)",
# ]
57 changes: 44 additions & 13 deletions biochatter/api_agent/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING

from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field, create_model, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, create_model

from biochatter.llm_connect import Conversation
if TYPE_CHECKING:
from biochatter.llm_connect import Conversation


class BaseQueryBuilder(ABC):
Expand Down Expand Up @@ -104,7 +106,7 @@ def fetch_results(
self,
query_models: list[BaseModel],
retries: int | None = 3,
):
) -> list[BaseModel]:
"""Fetch results by submitting a query.

Can implement a multi-step procedure if submitting and fetching are
Expand All @@ -116,6 +118,12 @@ def fetch_results(
query_models: list of Pydantic models describing the parameterised
queries

retries: The number of times to retry the query if it fails.

Returns:
-------
A list of Pydantic models containing the results of the queries.

"""


Expand Down Expand Up @@ -160,20 +168,43 @@ def summarise_results(
class BaseAPIModel(BaseModel):
"""A base class for all API models.

Includes default fields `uuid` and `method_name`.
Includes the default field `question_uuid` and a `model_config` that allows arbitrary types.
"""

uuid: str | None = Field(
None,
description="Unique identifier for the model instance",
question_uuid: str | None = Field(
None, description="Unique identifier for the question asked to the LLM",
)
model_config = ConfigDict(arbitrary_types_allowed=True)

class BaseTools():
"""Abstract base class for tools."""
class BaseTools:
"""Abstract base class for tools.

To build a class to parameterise a tool call, inherit a class from this
BaseTools class. You define the `tool_params` and `tool_descriptions` in
the child class and set them as attributes. Then you can call the
`make_pydantic_tools` method to create the parameterisable Pydantic models
for the tool call. See `anndata_agent` or `scanpy_pl` agent for examples.
"""

def make_pydantic_tools(self) -> list[BaseAPIModel]:
"""Uses pydantics create_model to create a list of pydantic tools from a dictionary of parameters"""
"""Create parameterisable Pydantic models for the tool call.

Creates a list of Pydantic models for the tool call, based on the
`tool_params` and `tool_descriptions` attributes.
"""
tools = []
for func_name, tool_params in self.tools_params.items():
tools.append(create_model(func_name, **tool_params, __base__=BaseAPIModel))
return tools
tool_params = self.tool_params
tool_descriptions = self.tool_descriptions
# validate that keys are equal in tool_params and tool_descriptions
if set(tool_params) != set(tool_descriptions):
msg = "Keys in tools_params and tools_descriptions must be equal"
raise ValueError(msg)
for tool_name in tool_descriptions:
parameters = tool_params[tool_name]
tool_description = tool_descriptions[tool_name]
tools.append(
create_model(
tool_name, __doc__=tool_description, **parameters, __base__=BaseAPIModel,
),
)
return tools
Loading