Skip to content

Commit

Permalink
strip markdown to prevent json errors downstream due to oai weirdness…
Browse files Browse the repository at this point in the history
….. try again to enforce determinism
  • Loading branch information
j2whiting committed Mar 5, 2024
1 parent cf09451 commit 7de8988
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
41 changes: 30 additions & 11 deletions core/openai/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from typing import List
from core.utils import (
remove_references,
extract_json,
normalize_greek_alphabet,
exceeds_tokens,
model_config_adapter,
postprocess_oai_json,
)
from core.openai.prompts.petrinet_config import PETRINET_PROMPT
from core.openai.prompts.model_card import MODEL_CARD_TEMPLATE, INSTRUCTIONS
Expand Down Expand Up @@ -36,13 +36,16 @@ def model_config_chain(research_paper: str, amr: str) -> dict:
client = OpenAI()
output = client.chat.completions.create(
model="gpt-4-0125-preview",
top_p=0,
max_tokens=4000,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
messages=[
{"role": "user", "content": prompt},
],
)
config = extract_json("{" + output.choices[0].message.content)
config = postprocess_oai_json(output.choices[0].message.content)
return model_config_adapter(config)


Expand All @@ -61,13 +64,16 @@ def model_card_chain(research_paper: str = None, amr: str = None) -> dict:
client = OpenAI()
output = client.chat.completions.create(
model="gpt-4-0125-preview",
top_p=0,
max_tokens=4000,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
messages=[
{"role": "user", "content": prompt},
],
)
model_card = extract_json("{" + output.choices[0].message.content)
model_card = postprocess_oai_json(output.choices[0].message.content)
if model_card is None:
return json.loads(MODEL_CARD_TEMPLATE)
return model_card
Expand All @@ -83,7 +89,10 @@ def condense_chain(query: str, chunks: List[str], max_tokens: int = 16385) -> st
client = OpenAI()
output = client.chat.completions.create(
model="gpt-3.5-turbo-0125",
top_p=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
max_tokens=1024,
messages=[
{"role": "user", "content": prompt},
Expand All @@ -108,10 +117,14 @@ async def amodel_card_chain(research_paper: str):
model="gpt-4-1106-preview",
messages=messages,
tools=functions,
top_p=0.0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
max_tokens=1024,
tool_choice=None,
)
model_card = extract_json("{" + response.choices[0].message.content)
model_card = postprocess_oai_json(response.choices[0].message.content)
if model_card is None:
return json.loads(MODEL_CARD_TEMPLATE)
return model_card
Expand Down Expand Up @@ -140,13 +153,16 @@ def config_from_dataset(amr: str, datasets: List[str]) -> str:
client = OpenAI()
output = client.chat.completions.create(
model="gpt-4-0125-preview",
top_p=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
max_tokens=4000,
messages=[
{"role": "user", "content": prompt},
],
)
return json.loads(output.choices[0].message.content)
return postprocess_oai_json(output.choices[0].message.content)


def compare_models(model_cards: List[str]) -> str:
Expand All @@ -156,7 +172,10 @@ def compare_models(model_cards: List[str]) -> str:
client = OpenAI()
output = client.chat.completions.create(
model="gpt-3.5-turbo-0125",
top_p=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
seed=123,
max_tokens=1024,
messages=[
{"role": "user", "content": prompt},
Expand Down
18 changes: 18 additions & 0 deletions core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ def remove_references(text: str) -> str:
return new_text.strip()


def parse_json_from_markdown(text: str) -> str:
    """Extract a JSON object from a fenced markdown code block.

    OpenAI models sometimes wrap JSON output in markdown fences even when
    asked not to, which breaks ``json`` parsing downstream.  When a fenced
    ``{...}`` object is found, return just the inner payload; otherwise
    return *text* unchanged so the caller can attempt to parse it as-is.
    """
    print("Stripping markdown...")
    # Accept both ```json and bare ``` fences — the model is inconsistent
    # about tagging the block.  DOTALL lets the object span multiple lines;
    # the non-greedy body stops at the first closing fence.
    json_pattern = r"```(?:json)?\s*(\{.*?\})\s*```"
    match = re.search(json_pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    print(f"No markdown found in text: {text}")
    return text


def extract_json(text: str) -> dict:
corrected_text = text.replace("{{", "{").replace("}}", "}")
try:
Expand All @@ -25,6 +36,13 @@ def extract_json(text: str) -> dict:
raise ValueError(f"Error decoding JSON: {e}\nfrom text {text}")


def postprocess_oai_json(output: str) -> dict:
    """Normalize an OpenAI chat-completion payload and parse it as JSON.

    Strips any markdown code fences, restores the leading ``{`` that the
    prompts use to seed the completion (so the model usually replies
    without it), and delegates parsing to :func:`extract_json`.
    """
    text = parse_json_from_markdown(output)
    # Only prepend the brace when it is actually missing.  Unconditionally
    # prepending relied on extract_json collapsing "{{" back to "{", which
    # also corrupts legitimate "{{"/"}}" sequences inside string values.
    if not text.lstrip().startswith("{"):
        text = "{" + text
    return extract_json(text)


def normalize_greek_alphabet(text: str) -> str:
greek_to_english = {
"α": "alpha",
Expand Down

0 comments on commit 7de8988

Please sign in to comment.