From 9893341e643abe797376cf6a9c3af9fb6f13928a Mon Sep 17 00:00:00 2001 From: mengerj Date: Thu, 12 Dec 2024 15:18:25 +0100 Subject: [PATCH 1/4] include BaseAPIModel --- .../generate_pydantic_classes_from_module.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/biochatter/api_agent/generate_pydantic_classes_from_module.py b/biochatter/api_agent/generate_pydantic_classes_from_module.py index ebaf3381..e461be34 100644 --- a/biochatter/api_agent/generate_pydantic_classes_from_module.py +++ b/biochatter/api_agent/generate_pydantic_classes_from_module.py @@ -12,9 +12,10 @@ from typing import Any from docstring_parser import parse -from langchain_core.pydantic_v1 import BaseModel, Field, create_model +from langchain_core.pydantic_v1 import Field, create_model +from biochatter.api_agent.abc import BaseAPIModel -def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]: +def generate_pydantic_classes(module: ModuleType) -> list[type[BaseAPIModel]]: """Generate Pydantic classes for each callable. For each callable (function/method) in a given module. Extracts parameters @@ -43,7 +44,7 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]: required. """ - base_attributes = set(dir(BaseModel)) + base_attributes = set(dir(BaseAPIModel)) classes_list = [] for name, func in inspect.getmembers(module, inspect.isfunction): @@ -111,7 +112,9 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]: # Create the Pydantic model tl_parameters_model = create_model( name, - **fields) + **fields, + __base__=BaseAPIModel + ) classes_list.append(tl_parameters_model) return classes_list @@ -120,4 +123,4 @@ def generate_pydantic_classes(module: ModuleType) -> list[type[BaseModel]]: #import scanpy as sc #generated_classes = generate_pydantic_classes(sc.tl) #for func in generated_classes: -# print(func.schema()) +#print(func.model_json_schema()) From a97e9857ce3ca0faab66bf86660971c3725fb7c2 Mon Sep 17 00:00:00 2001 From: mengerj Date: Thu, 12 Dec 2024 15:19:20 +0100 Subject: [PATCH 2/4] renamed anndata module to anndata_agent due to conflicts --- biochatter/api_agent/__init__.py | 2 +- biochatter/api_agent/{anndata.py => anndata_agent.py} | 2 +- test/test_api_agent.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename biochatter/api_agent/{anndata.py => anndata_agent.py} (99%) diff --git a/biochatter/api_agent/__init__.py b/biochatter/api_agent/__init__.py index 31ed011a..258361f8 100644 --- a/biochatter/api_agent/__init__.py +++ b/biochatter/api_agent/__init__.py @@ -1,5 +1,5 @@ from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder -from .anndata import AnnDataIOQueryBuilder, ReadCSV, ReadExcel, ReadH5AD, ReadHDF, ReadLoom, ReadMTX, ReadText, ReadZarr +from .anndata_agent import AnnDataIOQueryBuilder, ReadCSV, ReadExcel, ReadH5AD, ReadHDF, ReadLoom, ReadMTX, ReadText, ReadZarr from .api_agent import APIAgent from .bio_tools import BioToolsFetcher, BioToolsInterpreter, BioToolsQueryBuilder from .blast import ( diff --git a/biochatter/api_agent/anndata.py b/biochatter/api_agent/anndata_agent.py similarity index 99% rename from biochatter/api_agent/anndata.py rename to biochatter/api_agent/anndata_agent.py index 109d34f0..fd148ddc 100644 --- a/biochatter/api_agent/anndata.py +++ b/biochatter/api_agent/anndata_agent.py @@ -15,7 +15,7 @@ # from langchain_core.pydantic_v1 import BaseModel, Field from biochatter.llm_connect import Conversation -from .abc import BaseAPIModel, BaseQueryBuilder +from biochatter.api_agent.abc import BaseAPIModel, BaseQueryBuilder if TYPE_CHECKING: from biochatter.llm_connect import Conversation diff --git a/test/test_api_agent.py b/test/test_api_agent.py index ad23def5..e77baf4e 100644 --- a/test/test_api_agent.py +++ b/test/test_api_agent.py @@ -11,7 +11,7 @@ BaseInterpreter, BaseQueryBuilder, ) -from biochatter.api_agent.anndata import AnnDataIOQueryBuilder +from biochatter.api_agent.anndata_agent import AnnDataIOQueryBuilder from biochatter.api_agent.api_agent import APIAgent from biochatter.api_agent.blast import ( BLAST_QUERY_PROMPT, From 7998ea798e335c653109b406c3b969b2b92e233e Mon Sep 17 00:00:00 2001 From: mengerj Date: Thu, 12 Dec 2024 16:47:17 +0100 Subject: [PATCH 3/4] changing how pydantic classes are defined manually, alinging with automatic apporach --- biochatter/api_agent/__init__.py | 2 +- biochatter/api_agent/abc.py | 22 +-- biochatter/api_agent/anndata_agent.py | 190 ++++++++++---------------- test/test_api_agent.py | 2 +- 4 files changed, 85 insertions(+), 131 deletions(-) diff --git a/biochatter/api_agent/__init__.py b/biochatter/api_agent/__init__.py index 258361f8..e98155ca 100644 --- a/biochatter/api_agent/__init__.py +++ b/biochatter/api_agent/__init__.py @@ -1,5 +1,5 @@ from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder -from .anndata_agent import AnnDataIOQueryBuilder, ReadCSV, ReadExcel, ReadH5AD, ReadHDF, ReadLoom, ReadMTX, ReadText, ReadZarr +from .anndata_agent import AnnDataIOQueryBuilder from .api_agent import APIAgent from .bio_tools import BioToolsFetcher, BioToolsInterpreter, BioToolsQueryBuilder from .blast import ( diff --git a/biochatter/api_agent/abc.py b/biochatter/api_agent/abc.py index 356237e4..d19e87ce 100644 --- a/biochatter/api_agent/abc.py +++ b/biochatter/api_agent/abc.py @@ -8,7 +8,7 @@ from collections.abc import Callable from langchain_core.prompts import ChatPromptTemplate -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, create_model, ConfigDict from biochatter.llm_connect import Conversation @@ -166,13 +166,13 @@ class BaseAPIModel(BaseModel): uuid: str | None = Field( None, description="Unique identifier for the model instance" ) - method_name: str = Field(..., description="Name of the method to be executed") - - class Config: - """BaseModel class configuration. - - Ensures the model can be extended without strict type checking on - inherited fields. - """ - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) + +class BaseTools(): + """Abstract base class for tools.""" + def make_pydantic_tools(self) -> list[BaseAPIModel]: + """Uses pydantics create_model to create a list of pydantic tools from a dictionary of parameters""" + tools = [] + for func_name, tool_params in self.tools_params.items(): + tools.append(create_model(func_name, **tool_params, __base__=BaseAPIModel)) + return tools \ No newline at end of file diff --git a/biochatter/api_agent/anndata_agent.py b/biochatter/api_agent/anndata_agent.py index fd148ddc..46089bf1 100644 --- a/biochatter/api_agent/anndata_agent.py +++ b/biochatter/api_agent/anndata_agent.py @@ -15,15 +15,15 @@ # from langchain_core.pydantic_v1 import BaseModel, Field from biochatter.llm_connect import Conversation -from biochatter.api_agent.abc import BaseAPIModel, BaseQueryBuilder +from biochatter.api_agent.abc import BaseAPIModel, BaseQueryBuilder, BaseTools if TYPE_CHECKING: from biochatter.llm_connect import Conversation from typing import Optional - -from pydantic import BaseModel, Field +# Careful as this is not the same as the langchain_core.pydantic_v1 +from pydantic import BaseModel, Field, create_model ANNDATA_IO_QUERY_PROMPT = """ You are a world class algorithm, computational biologist with world leading knowledge @@ -92,111 +92,73 @@ `io.read_umi_tools(filename[, dtype])` - Reads a gzipped condensed count matrix from UMI Tools. """ - - -class ReadH5AD(BaseAPIModel): - """Read .h5ad-formatted hdf5 file.""" - - method_name: str = Field(default="io.read_h5ad", description="NEVER CHANGE") - filename: str = Field(default="dummy.h5ad", description="Path to the .h5ad file") - backed: Optional[str] = Field( - default=None, description="Mode to access file: None, 'r' for read-only" - ) - as_sparse: Optional[str] = Field( - default=None, description="Convert to sparse format: 'csr', 'csc', or None" - ) - as_sparse_fmt: Optional[str] = Field( - default=None, description="Sparse format if converting, e.g., 'csr'" - ) - index_unique: Optional[str] = Field( - default=None, description="Make index unique by appending suffix if needed" - ) - - -class ReadZarr(BaseAPIModel): - """Read from a hierarchical Zarr array store.""" - - method_name: str = Field(default="io.read_zarr", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.zarr", description="Path or URL to the Zarr store" - ) - - -class ReadCSV(BaseAPIModel): - """Read .csv file.""" - - method_name: str = Field(default="io.read_csv", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.csv", description="Path to the .csv file" - ) - delimiter: Optional[str] = Field( - None, description="Delimiter used in the .csv file" - ) - first_column_names: Optional[bool] = Field( - None, description="Whether the first column contains names" - ) - - -class ReadExcel(BaseAPIModel): - """Read .xlsx (Excel) file.""" - - method_name: str = Field(default="io.read_excel", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.xlsx", description="Path to the .xlsx file" - ) - sheet: Optional[str] = Field(None, description="Sheet name or index to read from") - dtype: Optional[str] = Field( - None, description="Data type for the resulting dataframe" - ) - - -class ReadHDF(BaseAPIModel): - """Read .h5 (hdf5) file.""" - - method_name: str = Field(default="io.read_hdf", description="NEVER CHANGE") - filename: str = Field(default="placeholder.h5", description="Path to the .h5 file") - key: Optional[str] = Field(None, description="Group key within the .h5 file") - - -class ReadLoom(BaseAPIModel): - """Read .loom-formatted hdf5 file.""" - - method_name: str = Field(default="io.read_loom", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.loom", description="Path to the .loom file" - ) - sparse: Optional[bool] = Field(None, description="Whether to read data as sparse") - cleanup: Optional[bool] = Field(None, description="Clean up invalid entries") - X_name: Optional[str] = Field(None, description="Name to use for X matrix") - obs_names: Optional[str] = Field( - None, description="Column to use for observation names" - ) - var_names: Optional[str] = Field( - None, description="Column to use for variable names" - ) - - -class ReadMTX(BaseAPIModel): - """Read .mtx file.""" - - method_name: str = Field(default="io.read_mtx", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.mtx", description="Path to the .mtx file" - ) - dtype: Optional[str] = Field(None, description="Data type for the matrix") - - -class ReadText(BaseAPIModel): - """Read .txt, .tab, .data (text) file.""" - - method_name: str = Field(default="io.read_text", description="NEVER CHANGE") - filename: str = Field( - default="placeholder.txt", description="Path to the text file" - ) - delimiter: Optional[str] = Field(None, description="Delimiter used in the file") - first_column_names: Optional[bool] = Field( - None, description="Whether the first column contains names" - ) +class Tools(BaseTools): + tools_params = {} + tools_params["io.read_h5ad"] = { + "filename": (str, Field(default="dummy.h5ad", description="Path to the .h5ad file")), + "backed": (Optional[str], Field(default=None, description="Mode to access file: None, 'r' for read-only")), + "as_sparse": (Optional[str], Field(default=None, description="Convert to sparse format: 'csr', 'csc', or None")), + "as_sparse_fmt": (Optional[str], Field(default=None, description="Sparse format if converting, e.g., 'csr'")), + "index_unique": (Optional[str], Field(default=None, description="Make index unique by appending suffix if needed")) + } + + # Parameters for io.read_zarr + tools_params["io.read_zarr"] = { + "method_name": (str, Field(default="io.read_zarr", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.zarr", description="Path or URL to the Zarr store")) + } + + # Parameters for io.read_csv + tools_params["io.read_csv"] = { + "method_name": (str, Field(default="io.read_csv", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.csv", description="Path to the .csv file")), + "delimiter": (Optional[str], Field(default=None, description="Delimiter used in the .csv file")), + "first_column_names": (Optional[bool], Field(default=None, description="Whether the first column contains names")) + } + + # Parameters for io.read_excel + tools_params["io.read_excel"] = { + "method_name": (str, Field(default="io.read_excel", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.xlsx", description="Path to the .xlsx file")), + "sheet": (Optional[str], Field(default=None, description="Sheet name or index to read from")), + "dtype": (Optional[str], Field(default=None, description="Data type for the resulting dataframe")) + } + + # Parameters for io.read_hdf + tools_params["io.read_hdf"] = { + "method_name": (str, Field(default="io.read_hdf", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.h5", description="Path to the .h5 file")), + "key": (Optional[str], Field(default=None, description="Group key within the .h5 file")) + } + + # Parameters for io.read_loom + tools_params["io.read_loom"] = { + "method_name": (str, Field(default="io.read_loom", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.loom", description="Path to the .loom file")), + "sparse": (Optional[bool], Field(default=None, description="Whether to read data as sparse")), + "cleanup": (Optional[bool], Field(default=None, description="Clean up invalid entries")), + "X_name": (Optional[str], Field(default=None, description="Name to use for X matrix")), + "obs_names": (Optional[str], Field(default=None, description="Column to use for observation names")), + "var_names": (Optional[str], Field(default=None, description="Column to use for variable names")) + } + + # Parameters for io.read_mtx + tools_params["io.read_mtx"] = { + "method_name": (str, Field(default="io.read_mtx", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.mtx", description="Path to the .mtx file")), + "dtype": (Optional[str], Field(default=None, description="Data type for the matrix")) + } + + # Parameters for io.read_text + tools_params["io.read_text"] = { + "method_name": (str, Field(default="io.read_text", description="NEVER CHANGE")), + "filename": (str, Field(default="placeholder.txt", description="Path to the text file")), + "delimiter": (Optional[str], Field(default=None, description="Delimiter used in the file")), + "first_column_names": (Optional[bool], Field(default=None, description="Whether the first column contains names")) + } + def __init__(self, tools_params: dict = tools_params): + super().__init__() + self.tools_params = tools_params class AnnDataIOQueryBuilder(BaseQueryBuilder): @@ -251,16 +213,8 @@ def parameterise_query( AnnDataIOQuery: the parameterised query object (Pydantic model) """ - tools = [ - ReadCSV, - ReadExcel, - ReadH5AD, - ReadHDF, - ReadLoom, - ReadMTX, - ReadText, - ReadZarr, - ] + tool_maker = Tools() + tools = tool_maker.make_pydantic_tools() runnable = self.create_runnable( conversation=conversation, query_parameters=tools ) diff --git a/test/test_api_agent.py b/test/test_api_agent.py index e77baf4e..fa3e522f 100644 --- a/test/test_api_agent.py +++ b/test/test_api_agent.py @@ -481,7 +481,7 @@ class TestAnndataIOQueryBuilder: @pytest.fixture def mock_create_runnable(self): with patch( - "biochatter.api_agent.anndata.AnnDataIOQueryBuilder.create_runnable", + "biochatter.api_agent.anndata_agent.AnnDataIOQueryBuilder.create_runnable", ) as mock: mock_runnable = MagicMock() mock.return_value = mock_runnable From 38595f6367265fcda36284b9c2c4aea00216228f Mon Sep 17 00:00:00 2001 From: slobentanzer Date: Thu, 12 Dec 2024 17:03:09 +0100 Subject: [PATCH 4/4] resolve import issue --- biochatter/api_agent/scanpy_tl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/biochatter/api_agent/scanpy_tl.py b/biochatter/api_agent/scanpy_tl.py index 64a06ee2..5b9fc82c 100644 --- a/biochatter/api_agent/scanpy_tl.py +++ b/biochatter/api_agent/scanpy_tl.py @@ -1,15 +1,12 @@ """Module for interacting with the `scanpy` API for data tools (`tl`).""" -from typing import TYPE_CHECKING - from langchain_core.output_parsers import PydanticToolsParser from langchain_core.pydantic_v1 import BaseModel from .abc import BaseQueryBuilder from .generate_pydantic_classes_from_module import generate_pydantic_classes +from biochatter.llm_connect import Conversation -if TYPE_CHECKING: - from biochatter.llm_connect import Conversation SCANPY_QUERY_PROMPT = """ You are a world class algorithm for creating queries in structured formats. Your task is to use the scanpy python package