-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fady/chains #9
Merged
Merged
Fady/chains #9
Changes from 8 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
cf44434
starting chains
fadynakhla fa89597
progress
fadynakhla 70c0b85
adding initial prompts
fadynakhla 9bcb112
WIP - but getting closer
fadynakhla 0c03edf
updates
fadynakhla 4266915
removing init and adding classmethod
fadynakhla 4acbdd6
finally matching chain runs
fadynakhla 624faaf
improving prompt
fadynakhla abfefd7
Merge branch 'master' into fady/chains
fadynakhla b349f61
lock no update
fadynakhla 8c1c203
zero temp
fadynakhla File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
from typing import Any, Dict, List, Optional | ||
import asyncio | ||
import xml.etree.ElementTree as ET | ||
|
||
import pydantic | ||
import loguru | ||
from langchain import Anthropic | ||
from langchain.chains.base import Chain | ||
from langchain.chains import LLMChain, StuffDocumentsChain | ||
from langchain.llms.base import LLM | ||
from langchain.callbacks.manager import CallbackManagerForChainRun | ||
from langchain.prompts import PromptTemplate | ||
from langchain.docstore.document import Document | ||
from langchain.vectorstores.base import VectorStoreRetriever | ||
from langchain.schema.output_parser import BaseOutputParser | ||
|
||
from dr_claude.retrieval.retriever import HuggingFAISS | ||
from dr_claude.retrieval.embeddings import HuggingFaceEncoderEmbeddingsConfig | ||
|
||
|
||
logger = loguru.logger | ||
|
||
|
||
class Symptom(pydantic.BaseModel): | ||
symptom: str | ||
present: bool | ||
input_documents: Optional[List[Document]] = None | ||
|
||
|
||
class SymptomList(pydantic.BaseModel): | ||
symptoms: List[Symptom] | ||
|
||
|
||
class XmlOutputParser(BaseOutputParser[str]): | ||
"""OutputParser that parses LLMResult into the top likely string..""" | ||
|
||
@property | ||
def lc_serializable(self) -> bool: | ||
"""Whether the class LangChain serializable.""" | ||
return True | ||
|
||
@property | ||
def _type(self) -> str: | ||
"""Return the output parser type for serialization.""" | ||
return "default" | ||
|
||
def parse(self, text: str) -> str: | ||
"""Returns the input text with no changes.""" | ||
return parse_xml_line(text.strip()) | ||
|
||
|
||
class MatchingChain(Chain): | ||
|
||
symptom_extract_chain: LLMChain | ||
stuff_retrievals_match_chain: StuffDocumentsChain | ||
retriever: VectorStoreRetriever | ||
|
||
def _call( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[CallbackManagerForChainRun] = None, | ||
) -> Dict[str, Any]: | ||
raw_symptom_extract = self.symptom_extract_chain(inputs) | ||
symptom_list = parse_raw_extract(raw_symptom_extract["text"]) | ||
for symptom in symptom_list.symptoms: # suboptimal but fine for now | ||
symptom.input_documents = self.retriever.get_relevant_documents(symptom.symptom) | ||
logger.info(f"Retrieved {len(symptom.input_documents)} documents for {symptom.symptom}") | ||
logger.debug(f"Retrieved documents: {symptom.input_documents}") | ||
return self.run_matching_batch(symptom_list) | ||
|
||
async def _acall( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[CallbackManagerForChainRun] = None, | ||
) -> Dict[str, Any]: | ||
raw_symptom_extract = await self.symptom_extract_chain.acall(inputs) | ||
symptom_list = parse_raw_extract(raw_symptom_extract["text"]) | ||
for symptom in symptom_list.symptoms: # suboptimal but fine for now | ||
symptom.retrievals = await self.retriever.aget_relevant_documents(symptom.symptom) | ||
return self.run_matching_batch(symptom_list) | ||
|
||
def run_matching_batch(self, symptom_list: SymptomList) -> List[Dict[str, Any]]: | ||
|
||
async def run_batched(symptom_list: SymptomList) -> List[Dict[str, Any]]: | ||
tasks = [] | ||
for symptom in symptom_list.symptoms: | ||
output = self.stuff_retrievals_match_chain.acall(dict(symptom)) | ||
tasks.append(output) | ||
return await asyncio.gather(*tasks) | ||
|
||
return asyncio.run(run_batched(symptom_list)) | ||
|
||
# def _validate_outputs(self, outputs: List[Dict[str, Any]]) -> None: | ||
# for output in outputs: | ||
# super()._validate_outputs(output) | ||
|
||
def prep_outputs( | ||
self, | ||
inputs: Dict[str, str], | ||
outputs: List[Dict[str, str]], | ||
return_only_outputs: bool = False, | ||
) -> Dict[str, str]: | ||
new_outputs = [] | ||
for output in outputs: | ||
new_output = super().prep_outputs(inputs, output) | ||
new_outputs.append(new_output) | ||
return new_outputs | ||
|
||
@property | ||
def input_keys(self) -> List[str]: | ||
return self.symptom_extract_chain.input_keys | ||
|
||
@property | ||
def output_keys(self) -> List[str]: | ||
return ["match", "present"] | ||
|
||
@classmethod | ||
def from_llm( | ||
cls, | ||
llm: LLM, | ||
symptom_extract_prompt: PromptTemplate, | ||
symptom_match_prompt: PromptTemplate, | ||
retrieval_config: HuggingFaceEncoderEmbeddingsConfig, | ||
texts: List[str], | ||
) -> "MatchingChain": | ||
symptom_extract_chain = LLMChain( | ||
llm=llm, | ||
prompt=symptom_extract_prompt, | ||
) | ||
symptom_match_chain = LLMChain( | ||
llm=llm, | ||
prompt=symptom_match_prompt, | ||
output_parser=XmlOutputParser(), | ||
) | ||
stuff_retrievals_match_chain = StuffDocumentsChain( | ||
llm_chain=symptom_match_chain, | ||
document_variable_name="retrievals", | ||
verbose=True, | ||
callbacks=[], | ||
output_key="match", | ||
) | ||
vectorstore = HuggingFAISS.from_model_config_and_texts(texts, retrieval_config) | ||
retriever = vectorstore.as_retriever() | ||
return cls( | ||
symptom_extract_chain=symptom_extract_chain, | ||
stuff_retrievals_match_chain=stuff_retrievals_match_chain, | ||
retriever=retriever, | ||
) | ||
|
||
@classmethod | ||
def from_anthropic( | ||
cls, | ||
symptom_extract_prompt: PromptTemplate, | ||
symptom_match_prompt: PromptTemplate, | ||
retrieval_config: HuggingFaceEncoderEmbeddingsConfig, | ||
texts: List[str], | ||
) -> "MatchingChain": | ||
anthropic = Anthropic( | ||
temperature=0.1, | ||
verbose=True, | ||
) | ||
return cls.from_llm( | ||
llm=anthropic, | ||
symptom_extract_prompt=symptom_extract_prompt, | ||
symptom_match_prompt=symptom_match_prompt, | ||
retrieval_config=retrieval_config, | ||
texts=texts, | ||
) | ||
|
||
|
||
def parse_xml_line(line: str) -> str: | ||
root = ET.fromstring(line) | ||
return root.text | ||
|
||
|
||
def parse_raw_extract(text: str) -> SymptomList: | ||
symptom_strings = text.strip().split("\n") | ||
symptoms = [] | ||
logger.debug(f"Raw symptom strings: {symptom_strings}") | ||
for symptom_string in symptom_strings: | ||
logger.debug(f"Single line response: {symptom_string}") | ||
symptom_string = parse_xml_line(symptom_string) | ||
name, present = symptom_string.split(":") | ||
symptom = Symptom(symptom=name.strip(), present=present.strip() == "yes") | ||
symptoms.append(symptom) | ||
return SymptomList(symptoms=symptoms) | ||
|
||
|
||
if __name__ == "__main__": | ||
from dr_claude.chains import prompts | ||
|
||
chain = MatchingChain.from_anthropic( | ||
symptom_extract_prompt=prompts.SYMPTOM_EXTRACT_PROMPT, | ||
symptom_match_prompt=prompts.SYMPTOM_MATCH_PROMPT, | ||
retrieval_config=HuggingFaceEncoderEmbeddingsConfig( | ||
model_name_or_path="bert-base-uncased", | ||
device="cpu", | ||
), | ||
texts=["fever", "cough", "headache", "sore throat", "runny nose"] | ||
) | ||
inputs = { | ||
"question": "Do you have a fever?", | ||
"response": "yes and I have a headache as well", | ||
} | ||
print(chain(inputs)) |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. empty file |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from langchain.prompts import PromptTemplate | ||
|
||
|
||
_symptom_extract_template = """Given the following conversation: | ||
Question: {question} | ||
Response: {response} | ||
|
||
Please write out the medical symptoms that appear as well as whether they are present. | ||
|
||
Your response should be in the following format please do not include any other text: | ||
<symptom> symptom1 : yes </symptom> | ||
<symptom> symptom2 : no </symptom> | ||
""" | ||
|
||
_symptom_match_template = """Given the symptom: {symptom} which of the following retrievals is the best match? | ||
Retrievals: | ||
{retrievals} | ||
|
||
Select only one and write it below in the following formatt: | ||
<match> match </match> | ||
|
||
Remember, do not include any other text and ensure your choice is in the provided retrievals. | ||
""" | ||
|
||
|
||
SYMPTOM_EXTRACT_PROMPT = PromptTemplate.from_template(_symptom_extract_template) | ||
SYMPTOM_MATCH_PROMPT = PromptTemplate.from_template(_symptom_match_template) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason for the 0.1?