-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from fadynakhla/fady/pick-action
Fady/pick action
- Loading branch information
Showing
10 changed files
with
193 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import Dict, List | ||
|
||
import loguru | ||
from langchain.chat_models import ChatAnthropic | ||
from langchain.chains import LLMChain | ||
from langchain.prompts import PromptTemplate | ||
|
||
from dr_claude import datamodels | ||
|
||
|
||
logger = loguru.logger | ||
|
||
|
||
_decision_template = """You will be given the context of a patient through a list of positive and negative symptoms. | ||
You will then be given a set of symptoms that an intelligent system has predicted are the next best questions to ask the patient. | ||
Your job is to choose the best action. | ||
Known patient state: | ||
positive symptoms: {positive_symptoms} | ||
negative symptoms: {negative_symptoms} | ||
Symptoms to consider: {symptoms} | ||
What is the the best symptom to ask the patient about? | ||
Remember to ensure the chosen symptom exactly matches one of those you are asked to consider. Do not provide any other information or text. | ||
Chosen Symptom: | ||
""" | ||
|
||
DECISION_PROMPT = PromptTemplate.from_template(_decision_template) | ||
|
||
|
||
class DecisionClaude: | ||
def __init__(self): | ||
self.chain = get_decision_claude() | ||
|
||
def __call__(self, actions: List[datamodels.Symptom], state): | ||
inputs = self.get_action_picker_inputs(actions, state) | ||
response = self.chain(inputs) | ||
action = response["text"].strip() | ||
logger.info(f"Chosen Action: {action}") | ||
return action | ||
|
||
def get_action_picker_inputs( | ||
self, actions: List[datamodels.Symptom], state | ||
) -> Dict[str, str]: | ||
return { | ||
"positive_symptoms": " | ".join( | ||
[action.name for action in state.pertinent_pos] | ||
), | ||
"negative_symptoms": " | ".join( | ||
[action.name for action in state.pertinent_neg] | ||
), | ||
"symptoms": " | ".join([action.name for action in actions]), | ||
} | ||
|
||
|
||
def get_decision_claude() -> LLMChain: | ||
return LLMChain( | ||
llm=ChatAnthropic(temperature=0.0, verbose=True), | ||
prompt=DECISION_PROMPT, | ||
verbose=True, | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import math | ||
import random | ||
import time | ||
import heapq | ||
|
||
from mcts import treeNode, mcts | ||
|
||
|
||
class MultiChildMixin: | ||
def search(self, initialState, top_k): | ||
self.root = treeNode(initialState, None) | ||
|
||
if self.limitType == "time": | ||
timeLimit = time.time() + self.timeLimit / 1000 | ||
while time.time() < timeLimit: | ||
self.executeRound() | ||
else: | ||
for i in range(self.searchLimit): | ||
self.executeRound() | ||
|
||
bestChild = self.getBestChild(self.root, 0, top_k) | ||
return self.getAction(self.root, bestChild) | ||
|
||
def getBestChild(self, node, explorationValue, top_k): | ||
node_values = [] | ||
for i, child in enumerate(node.children.values()): | ||
nodeValue = ( | ||
child.totalReward / child.numVisits | ||
+ explorationValue | ||
* math.sqrt(2 * math.log(node.numVisits) / child.numVisits) | ||
) | ||
# Use negative value because heapq is a min heap, i to break ties | ||
heapq.heappush(node_values, (-nodeValue, i, child)) | ||
# Keep only the top_k node values | ||
if len(node_values) > top_k: | ||
heapq.heappop(node_values) | ||
|
||
# Return the children associated with the top_k node values | ||
top_k_nodes = [heapq.heappop(node_values)[2] for _ in range(len(node_values))] | ||
# The nodes are popped in ascending order, so reverse the list | ||
top_k_nodes.reverse() | ||
return top_k_nodes | ||
|
||
def getAction(self, root, bestChild): | ||
nodes = [] | ||
for action, node in root.children.items(): | ||
if node in bestChild: | ||
nodes.append(action) | ||
return nodes | ||
|
||
|
||
class MultiChoiceMCTS(MultiChildMixin, mcts): | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters