Skip to content

Commit

Permalink
Merge pull request #25 from fadynakhla/fady/pick-action
Browse files Browse the repository at this point in the history
Fady/pick action
  • Loading branch information
fadynakhla authored Jul 30, 2023
2 parents 7e16908 + be3233a commit 19807aa
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 83 deletions.
63 changes: 63 additions & 0 deletions dr_claude/chains/decision_claude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Dict, List

import loguru
from langchain.chat_models import ChatAnthropic
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from dr_claude import datamodels


logger = loguru.logger


_decision_template = """You will be given the context of a patient through a list of positive and negative symptoms.
You will then be given a set of symptoms that an intelligent system has predicted are the next best questions to ask the patient.
Your job is to choose the best action.
Known patient state:
positive symptoms: {positive_symptoms}
negative symptoms: {negative_symptoms}
Symptoms to consider: {symptoms}
What is the the best symptom to ask the patient about?
Remember to ensure the chosen symptom exactly matches one of those you are asked to consider. Do not provide any other information or text.
Chosen Symptom:
"""

DECISION_PROMPT = PromptTemplate.from_template(_decision_template)


class DecisionClaude:
def __init__(self):
self.chain = get_decision_claude()

def __call__(self, actions: List[datamodels.Symptom], state):
inputs = self.get_action_picker_inputs(actions, state)
response = self.chain(inputs)
action = response["text"].strip()
logger.info(f"Chosen Action: {action}")
return action

def get_action_picker_inputs(
self, actions: List[datamodels.Symptom], state
) -> Dict[str, str]:
return {
"positive_symptoms": " | ".join(
[action.name for action in state.pertinent_pos]
),
"negative_symptoms": " | ".join(
[action.name for action in state.pertinent_neg]
),
"symptoms": " | ".join([action.name for action in actions]),
}


def get_decision_claude() -> LLMChain:
return LLMChain(
llm=ChatAnthropic(temperature=0.0, verbose=True),
prompt=DECISION_PROMPT,
verbose=True,
)
27 changes: 0 additions & 27 deletions dr_claude/chains/explain_yourself.py

This file was deleted.

6 changes: 5 additions & 1 deletion dr_claude/chains/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def from_anthropic(


def parse_xml_line(line: str) -> str:
root = ET.fromstring(line)
try:
root = ET.fromstring(line)
except ET.ParseError as e:
logger.error(f"Failed to parse XML line: {line}")
raise e
return root.text


Expand Down
4 changes: 2 additions & 2 deletions dr_claude/chains/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
Retrievals:
{retrievals}
Select only one and write it below in the following formatt:
Select only one and write it below in the following format:
<match> match </match>
Remember, do not include any other text and ensure your choice is in the provided retrievals.
Remember, do not include any other text, ensure your choice is in the provided retrievals, and follow the output format.
"""


Expand Down
53 changes: 53 additions & 0 deletions dr_claude/claude_mcts/multi_choice_mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import math
import random
import time
import heapq

from mcts import treeNode, mcts


class MultiChildMixin:
def search(self, initialState, top_k):
self.root = treeNode(initialState, None)

if self.limitType == "time":
timeLimit = time.time() + self.timeLimit / 1000
while time.time() < timeLimit:
self.executeRound()
else:
for i in range(self.searchLimit):
self.executeRound()

bestChild = self.getBestChild(self.root, 0, top_k)
return self.getAction(self.root, bestChild)

def getBestChild(self, node, explorationValue, top_k):
node_values = []
for i, child in enumerate(node.children.values()):
nodeValue = (
child.totalReward / child.numVisits
+ explorationValue
* math.sqrt(2 * math.log(node.numVisits) / child.numVisits)
)
# Use negative value because heapq is a min heap, i to break ties
heapq.heappush(node_values, (-nodeValue, i, child))
# Keep only the top_k node values
if len(node_values) > top_k:
heapq.heappop(node_values)

# Return the children associated with the top_k node values
top_k_nodes = [heapq.heappop(node_values)[2] for _ in range(len(node_values))]
# The nodes are popped in ascending order, so reverse the list
top_k_nodes.reverse()
return top_k_nodes

def getAction(self, root, bestChild):
nodes = []
for action, node in root.children.items():
if node in bestChild:
nodes.append(action)
return nodes


class MultiChoiceMCTS(MultiChildMixin, mcts):
...
4 changes: 3 additions & 1 deletion dr_claude/retrieval/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def encode(self, texts: Union[str, List[str]], pooling: str) -> torch.Tensor:
return embeddings

@classmethod
def from_config(cls, config: HuggingFaceEncoderEmbeddingsConfig) -> "HuggingFaceEncoderEmbeddings":
def from_config(
cls, config: HuggingFaceEncoderEmbeddingsConfig
) -> "HuggingFaceEncoderEmbeddings":
return cls(**config.dict())


Expand Down
11 changes: 8 additions & 3 deletions dr_claude/retrieval/retriever.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModel
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy

from dr_claude.retrieval.embeddings import HuggingFaceEncoderEmbeddings, HuggingFaceEncoderEmbeddingsConfig
from dr_claude.retrieval.embeddings import (
HuggingFaceEncoderEmbeddings,
HuggingFaceEncoderEmbeddingsConfig,
)


class HuggingFAISS(FAISS):

@classmethod
def from_model_config_and_texts(
cls,
Expand All @@ -16,7 +19,9 @@ def from_model_config_and_texts(
ids: Optional[List[str]] = None,
) -> "HuggingFAISS":
embeddings = HuggingFaceEncoderEmbeddings.from_config(model_config)
return cls.from_texts(texts, embeddings, metadatas, ids)
return cls.from_texts(
texts, embeddings, metadatas, ids, distance_strategy=DistanceStrategy.COSINE
)


if __name__ == "__main__":
Expand Down
70 changes: 58 additions & 12 deletions dr_claude/runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Union, Dict
from typing import List, Union, Dict
import mcts
from loguru import logger

from dr_claude import kb_reading, datamodels, chaining_the_chains
from dr_claude.retrieval import retriever
from dr_claude.claude_mcts import action_states
from dr_claude.claude_mcts import action_states, multi_choice_mcts
from dr_claude.chains import decision_claude, doctor, matcher, patient, prompts


def a():
def main():
reader = kb_reading.CSVKnowledgeBaseReader("data/ClaudeKnowledgeBase.csv")
kb = reader.load_knowledge_base()
matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)
Expand All @@ -21,24 +22,43 @@ def a():
fever = symptom
state.pertinent_pos.update([fever])
rollout_policy = action_states.ArgMaxDiagnosisRolloutPolicy()
searcher = mcts.mcts(timeLimit=3000, rolloutPolicy=rollout_policy)
searcher = multi_choice_mcts.MultiChoiceMCTS(
timeLimit=3000, rolloutPolicy=rollout_policy
)

action_picker = decision_claude.DecisionClaude()
note = ("The patient has syncope, vertigo, nausea and is sweating",)
embedding_model_name = "/data/models/RoSaBERTa_large/"
# embedding_model_name = "bert-base-uncased"
# embedding_model_name = "/data/models/RoSaBERTa_large/"
embedding_model_name = "roberta-large"
retrieval_config = retriever.HuggingFaceEncoderEmbeddingsConfig(
model_name_or_path=embedding_model_name,
device="cpu",
)
chain_chainer = chaining_the_chains.ChainChainer(
matcher_chain = matcher.MatchingChain.from_anthropic(
symptom_extract_prompt=prompts.SYMPTOM_EXTRACT_PROMPT,
symptom_match_prompt=prompts.SYMPTOM_MATCH_PROMPT,
retrieval_config=retrieval_config,
symptoms=list(set(symptom_name_to_symptom)),
texts=list(set(symptom_name_to_symptom)),
)
doc_chain = doctor.get_doc_chain()
patient_chain = patient.get_patient_chain()
chain_chainer = chaining_the_chains.ChainChainer(
matcher_chain=matcher_chain,
doc_chain=doc_chain,
patient_chain=patient_chain,
)
while not isinstance(
(action := searcher.search(initialState=state)), datamodels.Condition
(actions := searcher.search(initialState=state, top_k=5))[0],
datamodels.Condition,
):
assert isinstance(action, datamodels.Symptom)
logger.info(f"{action=}")
patient_symptom_response = chain_chainer.interaction(note, action.name)
assert isinstance(actions[0], datamodels.Symptom)
logger.info(f"{actions=}")
actions = [action for action in actions if valid_action(action, state)]

action_name = action_picker(actions=actions, state=state)

patient_symptom_response = chain_chainer.interaction(note, action_name)

new_positives = [
symptom_name_to_symptom[s.symptom_match.strip()]
for s in patient_symptom_response
Expand All @@ -52,10 +72,36 @@ def a():
state.pertinent_pos.update(new_positives)
state.pertinent_neg.update(new_negatives)

action = actions[0]
diagnosis = action
logger.info(f"Diagnosis: {diagnosis}")
print(chain_chainer.interaction("fever"))


def valid_action(
action: Union[datamodels.Condition, datamodels.Symptom],
state: action_states.SimulationNextActionState,
) -> bool:
return (
not isinstance(action, datamodels.Condition)
and action not in state.pertinent_pos
and action not in state.pertinent_neg
)


def get_action_picker_inputs(
actions: List[datamodels.Symptom], state
) -> Dict[str, str]:
return {
"positive_symptoms": " | ".join(
[action.name for action in state.pertinent_pos]
),
"negative_symptoms": " | ".join(
[action.name for action in state.pertinent_neg]
),
"symptoms": " | ".join([action.name for action in actions]),
}


if __name__ == "__main__":
main()
35 changes: 1 addition & 34 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ loguru = "^0.7.0"
anthropic = "^0.3.6"
nest-asyncio = "^1.5.7"
torch = "^2.0.1"
pip = "^23.2.1"
install = "^1.3.5"
nvidia-cuda-cupti-cu11 = "^11.8.87"
transformers = "^4.31.0"
fastapi = "^0.100.1"
uvicorn = "^0.23.1"
Expand Down

0 comments on commit 19807aa

Please sign in to comment.