diff --git a/gollm/openai/prompts/amr_enrichment.py b/gollm/openai/prompts/amr_enrichment.py
new file mode 100644
index 0000000..9960097
--- /dev/null
+++ b/gollm/openai/prompts/amr_enrichment.py
@@ -0,0 +1,25 @@
+ENRICH_PROMPT = """
+    You are a helpful agent designed to extract metadata associated with petrinet models. \
+    You will focus on extracting descriptions and units for each initial place and transition in the model.
+
+    For context:
+
+    In a Petri net model, initials represent the initial state of the system through the initial distribution of tokens across the places, known as the initial marking. Each place corresponds to a variable or state in the system, such as a species concentration in a reaction, and the number of tokens reflects the initial conditions of the ODEs.
+    Parameters define the system's rules and rates of evolution, including transition rates (analogous to reaction rates in ODEs) that determine how quickly tokens move between places. These parameters also include stoichiometric relationships, represented by the weights of arcs connecting places and transitions, dictating how many tokens are consumed or produced when a transition occurs.
+
+    Your initials and parameters to extract are: {param_initial_dict}
+
+    Extract descriptions and units from the following research paper: {paper_text}\n###PAPER END###\n
+
+    Please provide your output in the following json format:
+
+    {{'initials': {{'place1': {{'description': '...', 'units': '...'}}, 'place2': {{'description': '...', 'units': '...'}}, ...}}, 'parameters': {{'transition1': {{'description': '...', 'units': '...'}}, 'transition2': {{'description': '...', 'units': '...'}}, ...}}}}
+
+    Ensure that units are provided in both a unicode string and mathml format like so:
+
+    "units": {{ "expression": "1/(person*day)", "expression_mathml": "<apply><divide/><cn>1</cn><apply><times/><ci>person</ci><ci>day</ci></apply></apply>" }}
+
+    Where 'placeN' and 'transitionN' are the names of the initials and parameters to extract as found in the provided dictionary.
+
+    Begin:
+    """
diff --git a/gollm/openai/tool_utils.py b/gollm/openai/tool_utils.py
index 325cc21..d320805 100644
--- a/gollm/openai/tool_utils.py
+++ b/gollm/openai/tool_utils.py
@@ -8,6 +8,7 @@
     exceeds_tokens,
     model_config_adapter,
     postprocess_oai_json,
+    parse_param_initials,
 )
 from gollm.openai.prompts.petrinet_config import PETRINET_PROMPT
 from gollm.openai.prompts.model_card import MODEL_CARD_TEMPLATE, INSTRUCTIONS
@@ -15,6 +16,7 @@
 from gollm.openai.prompts.dataset_config import DATASET_PROMPT
 from gollm.openai.prompts.model_meta_compare import MODEL_METADATA_COMPARE_PROMPT
 from gollm.openai.prompts.general_instruction import GENERAL_INSTRUCTION_PROMPT
+from gollm.openai.prompts.amr_enrichment import ENRICH_PROMPT
 from gollm.openai.react import OpenAIAgent, AgentExecutor, ReActManager
 from gollm.openai.toolsets import DatasetConfig

@@ -29,7 +31,6 @@ def escape_curly_braces(text: str):


 def model_config_chain(research_paper: str, amr: str) -> dict:
     print("Reading model config from research paper: {}".format(research_paper[:100]))
-    research_paper = remove_references(research_paper)
     research_paper = normalize_greek_alphabet(research_paper)
     # probonto ontology file copied from https://github.com/gyorilab/mira/blob/e468059089681c7cd457acc51821b5bd1074df04/mira/dkg/resources/probonto.json
@@ -59,6 +60,30 @@ def model_config_chain(research_paper: str, amr: str) -> dict:
     config = postprocess_oai_json(output.choices[0].message.content)
     return model_config_adapter(config)

+
+def amr_enrichment_chain(amr: str, research_paper: str) -> dict:
+    amr_param_states = parse_param_initials(amr)
+    prompt = ENRICH_PROMPT.format(
+        param_initial_dict=amr_param_states,
+        paper_text=escape_curly_braces(research_paper)
+    )
+    client = OpenAI()
+    output = client.chat.completions.create(
+        model="gpt-4o-2024-05-13",
+        max_tokens=4000,
+        top_p=1,
+        frequency_penalty=0,
+        presence_penalty=0,
+        seed=123,
+        temperature=0,
+        response_format={"type": "json_object"},
+        messages=[
+            {"role": "user", "content": prompt},
+        ],
+    )
+    return postprocess_oai_json(output.choices[0].message.content)
+
+
 def model_card_chain(research_paper: str = None, amr: str = None) -> dict:
     print("Creating model card...")
     assert research_paper or amr, "Either research_paper or amr must be provided."
diff --git a/gollm/utils.py b/gollm/utils.py
index 1efe5e6..4e11f64 100644
--- a/gollm/utils.py
+++ b/gollm/utils.py
@@ -11,6 +11,28 @@ def remove_references(text: str) -> str:
     new_text = re.sub(pattern, "", text)
     return new_text.strip()

+def parse_param_initials(amr: dict):
+    try:
+        ode = amr['semantics']['ode']
+    except KeyError:
+        raise KeyError("ODE semantics not found in AMR, please provide a valid AMR with structure semantics.ode")
+
+    assert 'parameters' in ode, "No parameters found in ODE semantics, please provide a valid AMR with structure semantics.ode.parameters"
+    assert 'initials' in ode, "No initials found in ODE semantics, please provide a valid AMR with structure semantics.ode.initials"
+
+    params = ode['parameters']
+
+    assert all(['id' in p.keys() for p in params]), "All parameters must have an 'id' key"
+
+    param_ids = [p['id'] for p in params if p is not None and p.get('id')]
+
+    initials = ode['initials']
+
+    assert all(['target' in i.keys() for i in initials]), "All initials must have a 'target' key"
+
+    initial_ids = [i['target'] for i in initials if i is not None and i.get('target')]
+
+    return {'initial_names': initial_ids, 'param_names': param_ids}

 def parse_json_from_markdown(text):
     print("Stripping markdown...")
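Not part of the diff itself: a minimal sketch of how the new pieces are expected to fit together, for review purposes. The AMR fragment, variable names, and paper text below are illustrative assumptions, not taken from the PR; only `parse_param_initials` and `amr_enrichment_chain` come from the changes above.

```python
from gollm.utils import parse_param_initials

# Hypothetical AMR fragment: just the semantics.ode fields that
# parse_param_initials reads. Real AMR documents carry many more fields.
amr = {
    "semantics": {
        "ode": {
            "initials": [{"target": "S"}, {"target": "I"}, {"target": "R"}],
            "parameters": [{"id": "beta"}, {"id": "gamma"}],
        }
    }
}

print(parse_param_initials(amr))
# {'initial_names': ['S', 'I', 'R'], 'param_names': ['beta', 'gamma']}

# amr_enrichment_chain(amr, research_paper=paper_text) then formats
# ENRICH_PROMPT with this dict plus the paper text and returns the model's
# JSON enrichment (needs OPENAI_API_KEY; paper_text is a placeholder here).
```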