Skip to content

Commit

Permalink
Merge pull request #11 from fadynakhla/wian/more-mcts
Browse files Browse the repository at this point in the history
updated data with claude weights
  • Loading branch information
WianStipp authored Jul 30, 2023
2 parents 6ac8c68 + 8b4752a commit ca54953
Show file tree
Hide file tree
Showing 6 changed files with 2,863 additions and 17 deletions.
1,735 changes: 1,735 additions & 0 deletions data/ClaudeKnowledgeBase.csv

Large diffs are not rendered by default.

78 changes: 62 additions & 16 deletions dr_claude/kb_reading.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""This module provides functionality to parse and read in the knowledge bases."""

from typing import List, Tuple
import abc
import re
import collections
from bs4 import BeautifulSoup
from langchain import LLMChain
import pandas as pd


from dr_claude import datamodels

# Column names, in order, of the flat disease-symptom table: used as the
# DataFrame columns by both the NYPH symptom-table loader and
# make_df_from_knowledge_base.
KG_CSV_COLUMNS = ("Disease Code", "Disease", "Symptom Code", "Symptom")


class KnowledgeBaseReader(abc.ABC):
"""Abstract class for reading in disease-symptom knowledge bases."""
Expand All @@ -18,6 +22,27 @@ def load_knowledge_base(self) -> datamodels.DiseaseSymptomKnowledgeBase:
...


class CSVKnowledgeBaseReader(KnowledgeBaseReader):
    """Knowledge-base reader backed by a CSV file.

    The file is parsed with ``pd.read_csv`` when the reader is constructed;
    ``load_knowledge_base`` converts the resulting frame on every call.
    """

    def __init__(self, csv_path: str) -> None:
        """Read the CSV at *csv_path* into memory."""
        self._df = pd.read_csv(csv_path)

    def load_knowledge_base(self) -> datamodels.DiseaseSymptomKnowledgeBase:
        """Build the knowledge base from the frame read in ``__init__``."""
        return make_knowledge_base_from_df(self._df)


class LLMWeightUpdateReader(KnowledgeBaseReader):
    """Update the weights of an existing KnowledgeBase using an LLMChain.

    Wraps another reader: the wrapped reader supplies the conditions and
    symptoms, and the chain is asked for fresh symptom weights per condition.
    """

    def __init__(
        self, kb_reader: KnowledgeBaseReader, weight_updater_chain: LLMChain
    ) -> None:
        """Store the wrapped reader and the chain used to re-weight symptoms.

        Args:
            kb_reader: Reader that produces the knowledge base to re-weight.
            weight_updater_chain: Chain handed to
                ``weight_updating.get_updated_weights``.
        """
        self._kb_reader = kb_reader
        self._weight_updater_chain = weight_updater_chain

    def load_knowledge_base(self) -> datamodels.DiseaseSymptomKnowledgeBase:
        """Load the wrapped knowledge base and return it with LLM-updated weights.

        Previously this delegated to the abstract base method, whose body is
        ``...``, so it returned ``None`` and never used the wrapped reader or
        the chain. It now follows the same steps as ``weight_updating.main``.

        Returns:
            A knowledge base built from the mapping returned by
            ``get_updated_weights``. Conditions whose LLM output could not be
            parsed are left out of that mapping.

        Raises:
            RuntimeError: From ``asyncio.run`` if called while an event loop
                is already running in this thread.
        """
        # Local imports: weight_updating imports kb_reading inside its main(),
        # so importing it at module top here would risk a circular import.
        import asyncio

        from dr_claude import weight_updating

        base_kb = self._kb_reader.load_knowledge_base()
        updated_condition_symptoms = asyncio.run(
            weight_updating.get_updated_weights(
                base_kb.condition_symptoms, self._weight_updater_chain
            )
        )
        return datamodels.DiseaseSymptomKnowledgeBase(
            condition_symptoms=updated_condition_symptoms
        )


class NYPHKnowldegeBaseReader(KnowledgeBaseReader):
"""
knowledge database of disease-symptom associations generated by an
Expand All @@ -35,7 +60,7 @@ def load_symptom_df(self) -> pd.DataFrame:
soup = BeautifulSoup(self._html_content, "html.parser")
table = soup.find("table", {"class": "MsoTableWeb3"})
rows = table.find_all("tr")
data = []
data: List = []

for row in rows[1:]:
cells = row.find_all("td")
Expand Down Expand Up @@ -68,27 +93,48 @@ def load_symptom_df(self) -> pd.DataFrame:
)
return pd.DataFrame(
transformed_rows,
columns=["Disease Code", "Disease", "Symptom Code", "Symptom"],
columns=KG_CSV_COLUMNS,
)

def load_knowledge_base(self) -> datamodels.DiseaseSymptomKnowledgeBase:
    """Load the symptom table and convert it into a knowledge base.

    The table comes from ``load_symptom_df``; the conversion is done by
    ``make_knowledge_base_from_df`` with its default arguments.
    """
    return make_knowledge_base_from_df(self.load_symptom_df())


def make_df_from_knowledge_base(
    kb: datamodels.DiseaseSymptomKnowledgeBase,
) -> pd.DataFrame:
    """Flatten a knowledge base into one row per (condition, symptom) pair.

    Each row holds the condition's UMLS code and name followed by the
    symptom's UMLS code and name, under the ``KG_CSV_COLUMNS`` headers.
    Weights are not included.
    """
    flat_rows: List[Tuple[str, str, str, str]] = [
        (condition.umls_code, condition.name, symptom.umls_code, symptom.name)
        for condition, symptom_list in kb.condition_symptoms.items()
        for symptom in symptom_list
    ]
    return pd.DataFrame(flat_rows, columns=KG_CSV_COLUMNS)


def _cell_or_default(row: pd.Series, column: str, default: float) -> float:
    """Return ``row[column]``, or *default* if the column is absent or the cell is NaN."""
    value = row.get(column, default)
    return default if pd.isna(value) else value


def make_knowledge_base_from_df(
    df: pd.DataFrame, default_weight: float = 0.5, default_noise: float = 0.03
) -> datamodels.DiseaseSymptomKnowledgeBase:
    """Build a knowledge base from a flat disease-symptom table.

    Args:
        df: Table with "Disease Code", "Disease", "Symptom Code" and
            "Symptom" columns, and optionally "Weight" and "Noise".
        default_weight: Weight used when the "Weight" column is missing or
            the cell is empty (NaN).
        default_noise: Noise rate used when the "Noise" column is missing or
            the cell is empty (NaN).

    Returns:
        A knowledge base mapping each condition to its weighted symptoms.
    """
    condition_to_symptoms = collections.defaultdict(list)
    curr_disease = None
    for _, row in df.iterrows():
        # A new Condition is only created when the disease name differs from
        # the previous row's; consecutive rows of one disease share it.
        if row.Disease != curr_disease:
            curr_disease = row.Disease
            condition = datamodels.Condition(
                name=curr_disease, umls_code=row["Disease Code"]
            )
        condition_to_symptoms[condition].append(
            datamodels.WeightedSymptom(
                name=row.Symptom,
                umls_code=row["Symptom Code"],
                weight=_cell_or_default(row, "Weight", default_weight),
                noise_rate=_cell_or_default(row, "Noise", default_noise),
            )
        )
    return datamodels.DiseaseSymptomKnowledgeBase(
        condition_symptoms=condition_to_symptoms
    )


def parse_umls_string(umls_string):
Expand Down
164 changes: 164 additions & 0 deletions dr_claude/weight_updating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import asyncio
import logging
import xml.etree.ElementTree as ET
from abc import abstractmethod
from typing import Dict, List, Tuple
from xml.etree.ElementTree import ParseError

import pandas as pd
from langchain import LLMChain, PromptTemplate
from langchain.llms import Anthropic
from langchain.schema import BaseOutputParser
from tqdm import tqdm

from dr_claude import datamodels

# Default mapping from a frequency term to the symptom weight it stands for.
# The keys are the same four terms enumerated in the XML schema of the prompt
# built in main(); WeightedSymptomXMLOutputParser uses this mapping unless
# another one is supplied.
DEFAULT_FREQ_TERM_TO_WEIGHT = {
    "Very common": 0.9,
    "Common": 0.6,
    "Uncommon": 0.3,
    "Rare": 0.1,
}


# NOTE(review): the generic parameter says List[str], but parse() returns a
# list of datamodels.WeightedSymptom; left unchanged to keep the class's
# declared base as-is — confirm and correct if nothing depends on it.
class WeightedSymptomXMLOutputParser(BaseOutputParser[List[str]]):
    """Parse the XML output of an LLM call into a list of weighted symptoms.

    Args:
        frequency_term_to_weight: Mapping from a frequency term to the
            causal weight that it constitutes.
    """

    frequency_term_to_weight: Dict[str, float] = DEFAULT_FREQ_TERM_TO_WEIGHT

    @property
    def _type(self) -> str:
        return "xml"

    def parse(self, text: str) -> List[datamodels.WeightedSymptom]:
        """Parse the output of an LLM call.

        Every child of the root element is read as one symptom with a
        ``<name>`` and a ``<frequency>`` child. A frequency that is missing
        or not a key of ``frequency_term_to_weight`` gets ``min_weight``.

        Raises:
            xml.etree.ElementTree.ParseError: If *text* is not well-formed
                XML, or a symptom element has no non-empty ``<name>``.
        """
        # Surrounding whitespace is stripped first: an XML declaration that is
        # not at the very start of the document is rejected by the parser.
        root = ET.fromstring(text.strip())
        weights = self.frequency_term_to_weight
        symptoms = []
        for symptom_elem in root:
            # findtext returns None for a missing child (and "" for an empty
            # one) instead of failing with AttributeError on `.text`.
            name = (symptom_elem.findtext("name") or "").strip()
            if not name:
                # Reported as ParseError so it surfaces as the same
                # malformed-output error type that ET.fromstring raises.
                raise ET.ParseError("symptom element has no non-empty <name>")
            frequency = (symptom_elem.findtext("frequency") or "").strip()
            weight = weights.get(frequency)
            if weight is None:
                weight = self.min_weight
            symptom = datamodels.WeightedSymptom(
                umls_code="none", name=name, weight=weight
            )
            symptoms.append(symptom)
        return symptoms

    @property
    def min_weight(self) -> float:
        """Smallest weight in ``frequency_term_to_weight`` (the fallback weight)."""
        return min(self.frequency_term_to_weight.values())


def kb_to_dataframe(kb: datamodels.DiseaseSymptomKnowledgeBase) -> pd.DataFrame:
    """Flatten a knowledge base into one row per (condition, symptom) pair.

    Each row holds the condition's UMLS code and name, then the symptom's
    UMLS code, name, weight and noise rate.
    """
    cols = ("Disease Code", "Disease", "Symptom Code", "Symptom", "Weight", "Noise")
    # Six fields per row; the earlier annotation declared only four strings.
    rows: List[Tuple[str, str, str, str, float, float]] = [
        (
            condition.umls_code,
            condition.name,
            s.umls_code,
            s.name,
            s.weight,
            s.noise_rate,
        )
        for condition, symptoms in kb.condition_symptoms.items()
        for s in symptoms
    ]
    return pd.DataFrame(rows, columns=cols)


async def get_updated_weights(
    condition_symptoms: Dict[datamodels.Condition, List[datamodels.WeightedSymptom]],
    llm_chain: LLMChain,
    *,
    max_concurrency: int = 1,
) -> Dict[datamodels.Condition, List[datamodels.WeightedSymptom]]:
    """Ask the chain for updated symptom weights, one call per condition.

    Args:
        condition_symptoms: Conditions mapped to their current symptoms.
        llm_chain: Chain run with ``condition`` (the condition name) and
            ``symptoms_list`` (comma-separated symptom names).
        max_concurrency: Maximum number of chain calls in flight at once.
            Defaults to 1, the previously hard-coded value.

    Returns:
        Each condition mapped to whatever the chain returned for it.
        Conditions whose chain call raised ``ParseError`` are omitted (a
        warning is logged for each).

    Raises:
        ValueError: If ``max_concurrency`` is less than 1.
    """
    if max_concurrency < 1:
        # A zero-permit semaphore would make every call wait forever.
        raise ValueError("max_concurrency must be at least 1")

    logger = logging.getLogger(__name__)
    sem = asyncio.Semaphore(max_concurrency)  # max concurrent calls

    async def get_symptom_weights(
        condition: datamodels.Condition,
        symptoms: List[datamodels.WeightedSymptom],
    ):
        symptoms_str = ", ".join(s.name for s in symptoms)
        async with sem:
            try:
                result = await llm_chain.arun(
                    condition=condition.name, symptoms_list=symptoms_str
                )
            except ParseError:
                # Best effort: an unparsable answer drops this condition
                # only, but no longer silently.
                logger.warning(
                    "Could not parse LLM output for condition %r; skipping it",
                    condition.name,
                )
                return None
        return (condition, result)

    tasks = [
        asyncio.ensure_future(get_symptom_weights(condition, symptoms))
        for condition, symptoms in condition_symptoms.items()
    ]

    weight_updated_condition_symptoms = {}
    with tqdm(total=len(tasks)) as progress:
        for finished in asyncio.as_completed(tasks):
            outcome = await finished
            if outcome is not None:
                condition, result = outcome
                weight_updated_condition_symptoms[condition] = result
            progress.update(1)

    # Every task has already been awaited above, so no final gather is needed.
    return weight_updated_condition_symptoms


def main() -> None:
    """Re-weight the NYPH knowledge base with Claude and write it to CSV.

    Reads ``data/NYPHKnowldegeBase.html``, asks the chain for a frequency
    term per symptom of every condition, and saves the resulting knowledge
    base to ``data/ClaudeKnowledgeBase-1.csv``.
    """
    from dr_claude import kb_reading

    template = """Here is a list of symptoms for the condition {condition}.
Symptoms: {symptoms_list}.
Here is the output schema:
<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="covidSymptoms">
<xs:complexType>
<xs:sequence>
<xs:element name="symptom" maxOccurs="unbounded">
<xs:complexType>
<xs:sequence>
<xs:element name="name" type="xs:string"/>
<xs:element name="frequency">
<xs:simpleType>
<xs:restriction base="xs:string">
<xs:enumeration value="Very common"/>
<xs:enumeration value="Common"/>
<xs:enumeration value="Uncommon"/>
<xs:enumeration value="Rare"/>
</xs:restriction>
</xs:simpleType>
</xs:element>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>
Please parse the symptoms into the above schema, assigning a correct frequency value to each symptom.
"""

    # Chain: Claude completion -> XML parsed into weighted symptoms.
    chain = LLMChain(
        llm=Anthropic(model="claude-2", temperature=0.0, max_tokens_to_sample=2000),
        prompt=PromptTemplate.from_template(template),
        output_parser=WeightedSymptomXMLOutputParser(),
    )

    source_kb = kb_reading.NYPHKnowldegeBaseReader(
        "data/NYPHKnowldegeBase.html"
    ).load_knowledge_base()
    reweighted = asyncio.run(get_updated_weights(source_kb.condition_symptoms, chain))

    frame = kb_to_dataframe(
        datamodels.DiseaseSymptomKnowledgeBase(condition_symptoms=reweighted)
    )
    frame.to_csv("data/ClaudeKnowledgeBase-1.csv")


if __name__ == "__main__":
    main()
Loading

0 comments on commit ca54953

Please sign in to comment.