Skip to content

Commit

Permalink
Merge pull request #28 from fadynakhla/arthur/finalizing
Browse files Browse the repository at this point in the history
Arthur/finalizing
  • Loading branch information
ArthurBook authored Jul 30, 2023
2 parents 3bd2cc3 + 4095e41 commit 5a3e5bd
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 17 deletions.
50 changes: 47 additions & 3 deletions dr_claude/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Dict, Any, List
from typing import Optional, Tuple, Union, Dict, Any, List
import asyncio
import nest_asyncio
from uuid import UUID
Expand All @@ -23,6 +23,10 @@
import json
from loguru import logger

import transformers

transformers.set_seed(42)

app = FastAPI()


Expand Down Expand Up @@ -150,7 +154,12 @@ async def websocket_endpoint(websocket: WebSocket) -> None:
try:
message = await receive_message(websocket)
logger.info("Received {}", message)
await run_chain(note=message["content"], chainer=chainer)
diagnosis = await run_chain(
websocket=websocket,
note=message["content"],
chainer=chainer,
)
await websocket.send_json({"condition": diagnosis.name})
except WebSocketDisconnect:
logger.info("websocket disconnect")
break
Expand Down Expand Up @@ -180,7 +189,12 @@ async def receive_message(websocket: WebSocket):
return message_dict


async def run_chain(note: str, chainer: chaining_the_chains.ChainChainer):
K = 5


async def run_chain(
websocket: WebSocket, note: str, chainer: chaining_the_chains.ChainChainer
) -> Optional[datamodels.Condition]:
matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)
state = action_states.SimulationNextActionState(matrix, discount_rate=1e-9)

Expand All @@ -190,12 +204,18 @@ async def run_chain(note: str, chainer: chaining_the_chains.ChainChainer):
timeLimit=3000, rolloutPolicy=rollout_policy
)

q_counter = 0
while not isinstance(
(actions := searcher.search(initialState=state, top_k=5))[0],
datamodels.Condition,
):
top_k = get_top_k_condition_probas(state, k=K)
await websocket.send_json({"brain": [(c.name, p) for c, p in top_k]})

q_counter += 1
if q_counter > 10:
await websocket.send_json({"condition": "I'm sorry, I'm not sure."})
await websocket.close()
return
assert isinstance(actions[0], datamodels.Symptom)
logger.info(f"{actions=}")
Expand All @@ -206,11 +226,28 @@ async def run_chain(note: str, chainer: chaining_the_chains.ChainChainer):
new_positives, new_negatives = make_new_symptoms(
action_name, patient_symptom_response
)

logger.info(f"{new_positives=}")
logger.info(f"{new_negatives=}")
state.pertinent_pos.update(new_positives)
state.pertinent_neg.update(new_negatives)

diagnosis = actions[0]
logger.info(f"Diagnosis: {diagnosis}")
return diagnosis


def get_top_k_condition_probas(
state: action_states.SimulationNextActionState, k: int
) -> List[Tuple[datamodels.Condition, float]]:
condition_probas = state.getConditionProbabilityDict()
top_k = sorted(
condition_probas.items(),
key=lambda x: x[1],
reverse=True,
)[:K]
logger.info(f"{top_k=}")
return top_k


def make_new_symptoms(
Expand Down Expand Up @@ -239,3 +276,10 @@ def make_new_symptoms(
new_positives = []
new_negatives = [symptom_name_to_symptom[action_name.strip()]]
return new_positives, new_negatives


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
...
5 changes: 1 addition & 4 deletions dr_claude/chaining_the_chains.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from typing import Any, Dict, List
from typing import List
from dr_claude import datamodels
from langchain import LLMChain

from dr_claude.chains import doctor
from dr_claude.chains import patient
from dr_claude.chains import matcher
from dr_claude.chains import prompts
from dr_claude.retrieval import retriever
from loguru import logger
import time
Expand Down
7 changes: 6 additions & 1 deletion dr_claude/chains/doctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

_doc_prompt_template = (
"You are an insightful and inquisitive doctor. You are with a patient and need to inquire about a specific symptom: {symptom}.\n\n"
"Compose a single, direct question that exclusively probes the presence of this particular symptom. Ensure your response contains only this question, with no additional commentary or elements. The entire response should be the question itself."
"Compose a single, direct question that exclusively probes the presence of this particular symptom. "
"Ensure your response contains only this question, with no additional commentary or elements. The entire response should be the question itself."
"Keep the question simple. For instance, if the symptom was shortness of breath, you can ask 'Are you experiencing any shortness of breath?'"
"\nIf the symptom was abdominal cramps, you can ask 'Are have you had any abdominal cramps lately?'"
"\n\nNow, phrase a question that lets you confirm or reject whether the patient has the symptom {symptom}."
"\n\nQuestion:"
)
DOC_PROMPT = PromptTemplate.from_template(_doc_prompt_template)

Expand Down
17 changes: 10 additions & 7 deletions dr_claude/chains/patient.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
# "Answer:"
# )
_patient_prompt_template = (
"You are a patient who is visiting their doctor. Your medical condition is detailed as follows:\n\n"
"We are playing game where your role is to act as if a patient who is visiting their doctor."
"In the game, the doctor will ask you questions to try to find out what your condition is. "
"Your medical condition is detailed as follows:\n\n"
"{medical_note}\n\n"
"The doctor will ask you questions about your symptoms. It's important to answer each question one at a time and strictly according to the "
"information mentioned above. If the doctor inquires about a symptom not "
"discussed in the above note, your response should be 'no'. Additionally, avoid volunteering any other information unless specifically asked.\n\n"
"Question:\n{question}\n\n"
"Please remember that your answers should be focused solely on the specific question asked, without reference to other symptoms or information not requested.\n\n"
"Answer:"
"The games rules are the following:"
"\n1) Answer each question one at a time and strictly according to the information mentioned above."
"\n2) If the doctor inquires about a symptom not discussed in the above note, your response should be 'no'."
"\n3) You must not volunteer any other information that what is asked by the doctor. This is the most important rule "
"and it will ruin the game for everyone if you tell the doctor your symptoms before explicitly asked."
"Let's start the game!"
"\nThe doctor now asks you:\n{question}\n\n"
)
PATIENT_PROMPT = PromptTemplate.from_template(_patient_prompt_template)

Expand Down
4 changes: 2 additions & 2 deletions dr_claude/claude_mcts/multi_choice_mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mcts import treeNode, mcts


class MultiChildMixin:
class MultiChildMixin(mcts):
def search(self, initialState, top_k):
self.root = treeNode(initialState, None)

Expand All @@ -21,7 +21,7 @@ def search(self, initialState, top_k):
bestChild = self.getBestChild(self.root, 0, top_k)
return self.getAction(self.root, bestChild)

def getBestChild(self, node, explorationValue, top_k):
def getBestChild(self: mcts, node, explorationValue, top_k):
node_values = []
for i, child in enumerate(node.children.values()):
nodeValue = (
Expand Down

0 comments on commit 5a3e5bd

Please sign in to comment.