diff --git a/libs/manubot_ai_editor/env_vars.py b/libs/manubot_ai_editor/env_vars.py index 0ed9a06..a888bc6 100644 --- a/libs/manubot_ai_editor/env_vars.py +++ b/libs/manubot_ai_editor/env_vars.py @@ -16,7 +16,7 @@ OPENAI_API_KEY = "OPENAI_API_KEY" # Language model to use. For example, "text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", etc -# The tool currently supports the "chat/completions", "completions", and "edits" endpoints, and you can check +# The tool currently supports the "chat/completions" and "completions" endpoints, and you can check # compatible models here: https://platform.openai.com/docs/models/model-endpoint-compatibility LANGUAGE_MODEL = "AI_EDITOR_LANGUAGE_MODEL" diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py index c25dd68..80a2dcc 100644 --- a/libs/manubot_ai_editor/models.py +++ b/libs/manubot_ai_editor/models.py @@ -5,7 +5,8 @@ import time import json -import openai +from langchain_openai import OpenAI, ChatOpenAI +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from manubot_ai_editor import env_vars @@ -141,12 +142,13 @@ def __init__( super().__init__() # make sure the OpenAI API key is set - openai.api_key = openai_api_key + if openai_api_key is None: + # attempt to get the OpenAI API key from the environment, since one + # wasn't specified as an argument + openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None) - if openai.api_key is None: - openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None) - - if openai.api_key is None or openai.api_key.strip() == "": + # if it's *still* not set, bail + if openai_api_key is None or openai_api_key.strip() == "": raise ValueError( f"OpenAI API key not found. Please provide it as parameter " f"or set it as an the environment variable " @@ -221,7 +223,7 @@ def __init__( self.title = title self.keywords = keywords if keywords is not None else [] - # adjust options if edits or chat endpoint was selected + # adjust options if chat endpoint was selected self.endpoint = "chat" if model_engine.startswith( @@ -229,9 +231,6 @@ def __init__( ): self.endpoint = "completions" - if "-edit-" in model_engine: - self.endpoint = "edits" - print(f"Language model: {model_engine}") print(f"Model endpoint used: {self.endpoint}") @@ -253,6 +252,18 @@ def __init__( self.several_spaces_pattern = re.compile(r"\s+") + if self.endpoint == "chat": + client_cls = ChatOpenAI + else: + client_cls = OpenAI + + # construct the OpenAI client after all the rest of + # the settings above have been processed + self.client = client_cls( + api_key=openai_api_key, + **self.model_parameters, + ) + def get_prompt( self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None ) -> str | tuple[str, str]: @@ -268,13 +279,9 @@ def get_prompt( resolved_prompt: prompt resolved via ai-revision config, if available Returns: - If self.endpoint != "edits", then returns a string with the prompt to be used by the model for the revision of the paragraph. + A string with the prompt to be used by the model for the revision of the paragraph. It contains two paragraphs of text: the command for the model ("Revise...") and the paragraph to revise. - - If self.endpoint == "edits", then returns a tuple with two strings: - 1) the instructions to be used by the model for the revision of the paragraph, - 2) the paragraph to revise. 
""" # prompts are resolved in the following order, with the first satisfied @@ -310,8 +317,6 @@ def get_prompt( f"Using custom prompt from environment variable '{env_vars.CUSTOM_PROMPT}'" ) - # FIXME: if {paragraph_text} is in the prompt, this won't work for the edits endpoint - # a simple workaround is to remove {paragraph_text} from the prompt prompt = custom_prompt.format(**placeholders) elif resolved_prompt: # use the resolved prompt from the ai-revision config files, if available @@ -384,14 +389,10 @@ def get_prompt( if custom_prompt is None: prompt = self.several_spaces_pattern.sub(" ", prompt).strip() - if self.endpoint != "edits": - if custom_prompt is not None and "{paragraph_text}" in custom_prompt: - return prompt + if custom_prompt is not None and "{paragraph_text}" in custom_prompt: + return prompt - return f"{prompt}.\n\n{paragraph_text.strip()}" - else: - prompt = prompt.replace("the following paragraph", "this paragraph") - return f"{prompt}.", paragraph_text.strip() + return f"{prompt}.\n\n{paragraph_text.strip()}" def get_max_tokens(self, paragraph_text: str, fraction: float = 2.0) -> int: """ @@ -465,6 +466,22 @@ def get_max_tokens_from_error_message(error_message: str) -> dict[str, int] | No } def get_params(self, paragraph_text, section_name, resolved_prompt=None): + """ + Given the paragraph text and section name, produces parameters that are + used when invoking an LLM via an API. + + The specific parameters vary depending on the endpoint being used, which + is determined by the model that was chosen when GPT3CompletionModel was + instantiated. + + Args: + paragraph_text: The text of the paragraph to be revised. + section_name: The name of the section the paragraph belongs to. + resolved_prompt: The prompt resolved via ai-revision config files, if available. + + Returns: + A dictionary of parameters to be used when invoking an LLM API. + """ max_tokens = self.get_max_tokens(paragraph_text) prompt = self.get_prompt(paragraph_text, section_name, resolved_prompt) @@ -472,14 +489,7 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None): "n": 1, } - if self.endpoint == "edits": - params.update( - { - "instruction": prompt[0], - "input": prompt[1], - } - ) - elif self.endpoint == "chat": + if self.endpoint == "chat": params.update( { "messages": [ @@ -502,19 +512,23 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None): return params - def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None): + def revise_paragraph( + self, paragraph_text: str, section_name: str = None, resolved_prompt=None + ): """ It revises a paragraph using GPT-3 completion model. Arguments: paragraph_text (str): Paragraph text to revise. - section_name (str): Section name of the paragraph. - throw_error (bool): If True, it throws an error if the API call fails. - If False, it returns the original paragraph text. + section_name (str): Section name of the paragrap + resolved_prompt (str): Prompt resolved via ai-revision config files, if available. Returns: Revised paragraph text. 
""" + + # based on the paragraph text to revise and the section to which it + # belongs, constructs parameters that we'll use to query the LLM's API params = self.get_params(paragraph_text, section_name, resolved_prompt) retry_count = 0 @@ -526,17 +540,33 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv flush=True, ) - if self.endpoint == "edits": - completions = openai.Edit.create(**params) - elif self.endpoint == "chat": - completions = openai.ChatCompletion.create(**params) - else: - completions = openai.Completion.create(**params) + # map the prompt to langchain's prompt types, based on what + # kind of endpoint we're using + if "messages" in params: + # map the messages to langchain's message types + # based on the 'role' field + prompt = [ + ( + HumanMessage(content=msg["content"]) + if msg["role"] == "user" + else SystemMessage(content=msg["content"]) + ) + for msg in params["messages"] + ] + elif "prompt" in params: + prompt = [HumanMessage(content=params["prompt"])] + + response = self.client.invoke( + input=prompt, + max_tokens=params.get("max_tokens"), + stop=params.get("stop"), + ) - if self.endpoint == "chat": - message = completions.choices[0].message.content.strip() + if isinstance(response, BaseMessage): + message = response.content.strip() else: - message = completions.choices[0].text.strip() + message = response.strip() + except Exception as e: error_message = str(e) print(f"Error: {error_message}") @@ -583,10 +613,10 @@ class DebuggingManuscriptRevisionModel(GPT3CompletionModel): """ def __init__(self, *args, **kwargs): - if 'title' not in kwargs or kwargs['title'] is None: - kwargs['title'] = "Debugging Title" - if 'keywords' not in kwargs or kwargs['keywords'] is None: - kwargs['keywords'] = ["debugging", "keywords"] + if "title" not in kwargs or kwargs["title"] is None: + kwargs["title"] = "Debugging Title" + if "keywords" not in kwargs or kwargs["keywords"] is None: + kwargs["keywords"] = ["debugging", "keywords"] super().__init__(*args, **kwargs) diff --git a/libs/manubot_ai_editor/prompt_config.py b/libs/manubot_ai_editor/prompt_config.py index d2d9f6e..695952e 100644 --- a/libs/manubot_ai_editor/prompt_config.py +++ b/libs/manubot_ai_editor/prompt_config.py @@ -47,9 +47,9 @@ def __init__(self, config_dir: str | Path, title: str, keywords: str) -> None: # specify filename-to-prompt mappings; if both are present, we use # self.config.files, but warn the user that they should only use one if ( - self.prompts_files is not None and - self.config is not None and - self.config.get('files', {}).get('matchings') is not None + self.prompts_files is not None + and self.config is not None + and self.config.get("files", {}).get("matchings") is not None ): print( "WARNING: Both 'ai-revision-config.yaml' and 'ai-revision-prompts.yaml' specify filename-to-prompt mappings. " @@ -93,7 +93,7 @@ def _load_custom_prompts(self) -> tuple[dict, dict]: # same as _load_config, if no config folder was specified, we just if self.config_dir is None: return (None, None) - + prompt_file_path = os.path.join(self.config_dir, "ai-revision-prompts.yaml") try: @@ -150,7 +150,7 @@ def get_prompt_for_filename( # ai-revision-prompts.yaml specifies prompts_files, then files.matchings # takes precedence. 
# (the user is notified of this in a validation warning in __init__) - + # then, consult ai-revision-config.yaml's 'matchings' collection if a # match is found, use the prompt ai-revision-prompts.yaml for entry in get_obj_path(self.config, ("files", "matchings"), missing=[]): @@ -169,7 +169,10 @@ def get_prompt_for_filename( if resolved_prompt is not None: resolved_prompt = resolved_prompt.strip() - return ( resolved_prompt, m, ) + return ( + resolved_prompt, + m, + ) # since we haven't found a match yet, consult ai-revision-prompts.yaml's # 'prompts_files' collection @@ -185,11 +188,10 @@ def get_prompt_for_filename( resolved_default_prompt = None if use_default and self.prompts is not None: resolved_default_prompt = self.prompts.get( - get_obj_path(self.config, ("files", "default_prompt")), - None + get_obj_path(self.config, ("files", "default_prompt")), None ) if resolved_default_prompt is not None: resolved_default_prompt = resolved_default_prompt.strip() - + return (resolved_default_prompt, None) diff --git a/setup.py b/setup.py index 531fc0b..bbcfb83 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name="manubot-ai-editor", - version="0.5.2", + version="0.5.3", author="Milton Pividori", author_email="miltondp@gmail.com", description="A Manubot plugin to revise a manuscript using GPT-3", @@ -25,7 +25,8 @@ ], python_requires=">=3.10", install_requires=[ - "openai==0.28", + "langchain-core~=0.3.6", + "langchain-openai~=0.2.0", "pyyaml", ], classifiers=[ diff --git a/tests/test_editor.py b/tests/test_editor.py index 60bb778..1068a6a 100644 --- a/tests/test_editor.py +++ b/tests/test_editor.py @@ -610,9 +610,7 @@ def test_revise_methods_with_equation_that_was_alrady_revised( # GPT3CompletionModel(None, None), ], ) -def test_revise_methods_mutator_epistasis_paper( - tmp_path, model, filename -): +def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename): """ This papers has several test cases: - it ends with multiple blank lines @@ -635,7 +633,7 @@ def test_revise_methods_mutator_epistasis_paper( ) assert ( - r""" + r""" %%% PARAGRAPH START %%% Briefly, we identified private single-nucleotide mutations in each BXD that were absent from all other BXDs, as well as from the C57BL/6J and DBA/2J parents. We required each private variant to be meet the following criteria: @@ -651,11 +649,11 @@ def test_revise_methods_mutator_epistasis_paper( * must occur on a parental haplotype that was inherited by at least one other BXD at the same locus; these other BXDs must be homozygous for the reference allele at the variant site %%% PARAGRAPH END %%% """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) - + assert ( - r""" + r""" ### Extracting mutation signatures We used SigProfilerExtractor (v.1.1.21) [@PMID:30371878] to extract mutation signatures from the BXD mutation data. @@ -678,11 +676,11 @@ def test_revise_methods_mutator_epistasis_paper( ### Comparing mutation spectra between Mouse Genomes Project strains """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) - + assert ( - r""" + r""" %%% PARAGRAPH START %%% We investigated the region implicated by our aggregate mutation spectrum distance approach on chromosome 6 by subsetting the joint-genotyped BXD VCF file (European Nucleotide Archive accession PRJEB45429 [@url:https://www.ebi.ac.uk/ena/browser/view/PRJEB45429]) using `bcftools` [@PMID:33590861]. 
We defined the candidate interval surrounding the cosine distance peak on chromosome 6 as the 90% bootstrap confidence interval (extending from approximately 95 Mbp to 114 Mbp). @@ -693,7 +691,7 @@ def test_revise_methods_mutator_epistasis_paper( java -Xmx16g -jar /path/to/snpeff/jarfile GRCm38.75 /path/to/bxd/vcf > /path/to/uncompressed/output/vcf ``` """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) diff --git a/tests/test_model_basics.py b/tests/test_model_basics.py index 8242ca9..54ec074 100644 --- a/tests/test_model_basics.py +++ b/tests/test_model_basics.py @@ -9,7 +9,6 @@ import pytest from manubot_ai_editor.editor import ManuscriptEditor, env_vars -from manubot_ai_editor import models from manubot_ai_editor.models import GPT3CompletionModel, RandomManuscriptRevisionModel MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" @@ -32,12 +31,12 @@ def test_model_object_init_without_openai_api_key(): @mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"}) def test_model_object_init_with_openai_api_key_as_environment_variable(): - GPT3CompletionModel( + model = GPT3CompletionModel( title="Test title", keywords=["test", "keywords"], ) - assert models.openai.api_key == "env_var_test_value" + assert model.client.openai_api_key.get_secret_value() == "env_var_test_value" def test_model_object_init_with_openai_api_key_as_parameter(): @@ -46,30 +45,26 @@ def test_model_object_init_with_openai_api_key_as_parameter(): if env_vars.OPENAI_API_KEY in os.environ: os.environ.pop(env_vars.OPENAI_API_KEY) - GPT3CompletionModel( + model = GPT3CompletionModel( title="Test title", keywords=["test", "keywords"], openai_api_key="test_value", ) - from manubot_ai_editor import models - - assert models.openai.api_key == "test_value" + assert model.client.openai_api_key.get_secret_value() == "test_value" finally: os.environ = _environ @mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"}) def test_model_object_init_with_openai_api_key_as_parameter_has_higher_priority(): - GPT3CompletionModel( + model = GPT3CompletionModel( title="Test title", keywords=["test", "keywords"], openai_api_key="test_value", ) - from manubot_ai_editor import models - - assert models.openai.api_key == "test_value" + assert model.client.openai_api_key.get_secret_value() == "test_value" def test_model_object_init_default_language_model(): diff --git a/tests/test_model_get_prompt.py b/tests/test_model_get_prompt.py index 538e654..438a16d 100644 --- a/tests/test_model_get_prompt.py +++ b/tests/test_model_get_prompt.py @@ -28,36 +28,6 @@ def test_get_prompt_for_abstract(): assert " " not in prompt -def test_get_prompt_for_abstract_edit_endpoint(): - manuscript_title = "Title of the manuscript to be revised" - manuscript_keywords = ["keyword0", "keyword1", "keyword2"] - - model = GPT3CompletionModel( - title=manuscript_title, - keywords=manuscript_keywords, - model_engine="text-davinci-edit-001", - ) - - paragraph_text = "Text of the abstract. 
" - - instruction, paragraph = model.get_prompt(paragraph_text, "abstract") - assert instruction is not None - assert isinstance(instruction, str) - assert paragraph is not None - assert isinstance(paragraph, str) - - assert "this paragraph" in instruction - assert "abstract" in instruction - assert f"'{manuscript_title}'" in instruction - assert f"{manuscript_keywords[0]}" in instruction - assert f"{manuscript_keywords[1]}" in instruction - assert f"{manuscript_keywords[2]}" in instruction - assert " " not in instruction - assert instruction.startswith("Revise") - - assert paragraph_text.strip() == paragraph - - def test_get_prompt_for_introduction(): manuscript_title = "Title of the manuscript to be revised" manuscript_keywords = ["keyword0", "keyword1", "keyword2"] diff --git a/tests/test_prompt_config.py b/tests/test_prompt_config.py index 7f68702..4d32b6a 100644 --- a/tests/test_prompt_config.py +++ b/tests/test_prompt_config.py @@ -5,7 +5,7 @@ from manubot_ai_editor.models import ( GPT3CompletionModel, RandomManuscriptRevisionModel, - DebuggingManuscriptRevisionModel + DebuggingManuscriptRevisionModel, ) from manubot_ai_editor.prompt_config import IGNORE_FILE import pytest @@ -13,7 +13,9 @@ from utils.dir_union import mock_unify_open MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "content" -MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci" +MANUSCRIPTS_CONFIG_DIR = ( + Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci" +) # check that this path exists and resolve it @@ -42,7 +44,9 @@ def test_create_manuscript_editor(): # check that we can resolve a file to a prompt, and that it's the correct prompt -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_resolve_prompt(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -100,7 +104,9 @@ def test_resolve_prompt(): # test that we get the default prompt with a None match object for a # file we don't recognize -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_resolve_default_prompt_unknown_file(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -114,7 +120,9 @@ def test_resolve_default_prompt_unknown_file(): # check that a file we don't recognize gets match==None and the 'default' prompt # from the ai-revision-config.yaml file -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_unresolved_gets_default_prompt(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -150,7 +158,9 @@ def test_unresolved_gets_default_prompt(): # - Both ai-revision-config.yaml and ai-revision-prompts.yaml specify filename matchings # (conflicting_promptsfiles_matchings) CONFLICTING_PROMPTSFILES_MATCHINGS_DIR = ( - Path(__file__).parent / "config_loader_fixtures" / "conflicting_promptsfiles_matchings" + Path(__file__).parent + / "config_loader_fixtures" + / "conflicting_promptsfiles_matchings" ) # --- # test ManuscriptEditor.prompt_config sub-attributes 
are set correctly @@ -178,7 +188,9 @@ def test_no_config_unloaded(): assert editor.prompt_config.config is None -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR) +) def test_only_rev_prompts_loaded(): editor = get_editor() @@ -188,7 +200,9 @@ def test_only_rev_prompts_loaded(): assert editor.prompt_config.config is None -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_both_prompts_loaded(): editor = get_editor() @@ -211,7 +225,8 @@ def test_single_generic_loaded(): @mock.patch( - "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR) + "builtins.open", + mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR), ) def test_conflicting_sources_warning(capfd): """ @@ -234,7 +249,7 @@ def test_conflicting_sources_warning(capfd): # for this test, we define both prompts_files and files.matchings which # creates a conflict that produces the warning we're looking for assert editor.prompt_config.prompts_files is not None - assert editor.prompt_config.config['files']['matchings'] is not None + assert editor.prompt_config.config["files"]["matchings"] is not None expected_warning = ( "WARNING: Both 'ai-revision-config.yaml' and " @@ -262,11 +277,13 @@ def test_conflicting_sources_warning(capfd): RandomManuscriptRevisionModel(), DebuggingManuscriptRevisionModel( title="Test title", keywords=["test", "keywords"] - ) + ), # GPT3CompletionModel(None, None), ], ) -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_revise_entire_manuscript(tmp_path, model): print(f"\n{str(tmp_path)}\n") me = get_editor() @@ -284,7 +301,9 @@ def test_revise_entire_manuscript(tmp_path, model): assert len(output_md_files) == 9 -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_revise_entire_manuscript_includes_title_keywords(tmp_path): from os.path import basename @@ -317,8 +336,12 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path): with open(output_md_file, "r") as f: content = f.read() - assert me.title in content, f"not found in filename: {basename(output_md_file)}" - assert ", ".join(me.keywords) in content, f"not found in filename: {basename(output_md_file)}" + assert ( + me.title in content + ), f"not found in filename: {basename(output_md_file)}" + assert ( + ", ".join(me.keywords) in content + ), f"not found in filename: {basename(output_md_file)}" # ============================================================================== @@ -329,7 +352,11 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path): Path(__file__).parent / "config_loader_fixtures" / "prompt_propogation" ) -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR)) + +@mock.patch( + "builtins.open", + mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR), +) def test_prompts_in_final_result(tmp_path): """ Tests that the prompts are making it into the final resulting .md files. 
@@ -348,9 +375,7 @@ def test_prompts_in_final_result(tmp_path): """ me = get_editor() - model = DebuggingManuscriptRevisionModel( - title=me.title, keywords=me.keywords - ) + model = DebuggingManuscriptRevisionModel(title=me.title, keywords=me.keywords) output_folder = tmp_path assert output_folder.exists() @@ -361,7 +386,8 @@ def test_prompts_in_final_result(tmp_path): files_to_prompts = { "00.front-matter.md": "This is the front-matter prompt.", "01.abstract.md": "This is the abstract prompt", - "02.introduction.md": "This is the introduction prompt for the paper titled '%s'." % me.title, + "02.introduction.md": "This is the introduction prompt for the paper titled '%s'." + % me.title, # "04.00.results.md": "This is the results prompt", "04.05.00.results_framework.md": "This is the results_framework prompt", "04.05.01.crispr.md": "This is the crispr prompt", @@ -389,15 +415,26 @@ def test_prompts_in_final_result(tmp_path): # to save on time/cost, we use a version of the phenoplier manuscript that only # contains the first paragraph of each section -BRIEF_MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "content" -BRIEF_MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci" +BRIEF_MANUSCRIPTS_DIR = ( + Path(__file__).parent + / "manuscripts" + / "phenoplier_full_only_first_para" + / "content" +) +BRIEF_MANUSCRIPTS_CONFIG_DIR = ( + Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci" +) PROMPT_PROPOGATION_CONFIG_DIR = ( Path(__file__).parent / "config_loader_fixtures" / "prompt_gpt3_e2e" ) + @pytest.mark.cost -@mock.patch("builtins.open", mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR)) +@mock.patch( + "builtins.open", + mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR), +) def test_prompts_apply_gpt3(tmp_path): """ Tests that the custom prompts are applied when actually applying @@ -408,16 +445,15 @@ def test_prompts_apply_gpt3(tmp_path): this test is marked 'cost' and requires the --runcost argument to be run, e.g. to run just this test: `pytest --runcost -k test_prompts_apply_gpt3`. - As with test_prompts_in_final_result above, files that have no input and + As with test_prompts_in_final_result above, files that have no input and thus no applied prompt are ignored. """ - me = get_editor(content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR) - - model = GPT3CompletionModel( - title=me.title, - keywords=me.keywords + me = get_editor( + content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR ) + model = GPT3CompletionModel(title=me.title, keywords=me.keywords) + output_folder = tmp_path assert output_folder.exists()
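For reference, a minimal sketch of how the `langchain_openai` client introduced in `models.py` is constructed and invoked, mirroring the new code in `GPT3CompletionModel.__init__` and `revise_paragraph`. The model name, prompt text, and environment-based key lookup below are illustrative assumptions, not values taken from this changeset.

```python
import os

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# as in GPT3CompletionModel.__init__, the key is read from the environment when
# not passed explicitly; the model name here is a placeholder, not a project default
client = ChatOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model="gpt-3.5-turbo",
)

# a chat-endpoint prompt expressed as langchain message objects; the role-to-message
# mapping follows the same pattern as revise_paragraph (user -> HumanMessage,
# anything else -> SystemMessage)
messages = [
    SystemMessage(content="You are a copy editor for scientific manuscripts."),
    HumanMessage(content="Revise the following paragraph:\n\nText of the paragraph to revise."),
]

# invoke() replaces the old openai.ChatCompletion.create() call; for chat models it
# returns a message object whose .content holds the revised text
response = client.invoke(input=messages, max_tokens=256)
print(response.content.strip())
```

For completion-style models the same `invoke()` call returns a plain string rather than a message object, which is why `revise_paragraph` now branches on `isinstance(response, BaseMessage)` before extracting the revised text.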