Skip to content

Commit

Permalink
use the dynamic task name for prompt template
Browse files Browse the repository at this point in the history
  • Loading branch information
p3nGu1nZz committed Sep 12, 2024
1 parent 76dd3da commit 29bf42c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
1 change: 0 additions & 1 deletion oproof/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,3 @@ class Const:
ARG_PROMPT_TEXT_HELP = "Input prompt"
ARG_RESPONSE_TEXT = "response"
ARG_RESPONSE_TEXT_HELP = "Input response"
VALIDATE_TASK = "validate"
5 changes: 4 additions & 1 deletion oproof/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def run(parsed_args):
Log.setup(parsed_args.debug)
if parsed_args.debug:
Log.start_main_function()
main._execute(parsed_args.prompt, parsed_args.response, parsed_args.debug, parsed_args.prompt)
main._execute(parsed_args.prompt, parsed_args.response, parsed_args.debug, parsed_args.prompt, parsed_args.prompt)
except Exception as e:
handle_error(e, parsed_args.debug)

Expand All @@ -41,6 +41,9 @@ def _execute(self, prompt: str, response: str, debug: bool, include_prompts: boo
console.print(JSON(json_output))
except ValidationError as e:
handle_error(e, debug)
except KeyError as e:
error_message = f"Missing key in validation result: {e}"
handle_error(error_message, debug)
except Exception as e:
handle_error(e, debug)
Log.error("Terminating script due to critical error.")
Expand Down
10 changes: 6 additions & 4 deletions oproof/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .template import Template
from .log import Log
import ollama as oll
from httpx import ConnectError # Import the ConnectError exception
from httpx import ConnectError

class Task:
def __init__(self, cfg):
Expand Down Expand Up @@ -41,18 +41,20 @@ def _log_error_and_raise(self, error_message: str, exception_message: str) -> No
raise Exception(exception_message)

def _render_prompt(self, prompt: str, response: str, template, system_prompt, instructions) -> str:
task_name = list(Template.TASKS.keys())[0]
return template.render(
system=system_prompt,
task=Const.VALIDATE_TASK, # Use the constant for the task name
task=task_name,
text=prompt,
example=Template.TASKS[Const.VALIDATE_TASK], # Use the corresponding task
example=Template.TASKS[task_name],
instructions=instructions,
lang=self.cfg.lang
)

def _post_process(self, prompt: str) -> str:
task_name = list(Template.TASKS.keys())[0]
replacements = {
"{{ task }}": Const.VALIDATE_TASK, # Use the constant for the task name
"{{ task }}": task_name,
"{{ lang }}": Const.LANG_DEFAULT
}
for key, value in replacements.items():
Expand Down
4 changes: 2 additions & 2 deletions oproof/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class Template:
"Example: {{ example }}\n"
"User: {{ prompt }}\n"
"Response: {{ response }}\n"
"System: Return only a JSON object with the validation result. No explanations, only JSON object; eg. {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}"
"System: Return only a JSON object with the validation result. No explanations, only JSON object; e.g., {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}"
)

TEMPLATES = {
"validation": T(PROMPT_TEMPLATE)
}

TASKS = {
"validate": "Validate the response for the given prompt."
"proofs": "Proof the given prompt and response pair of input text strings. e.g., 'What is 2 + 2?' '4' returns {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}"
}

0 comments on commit 29bf42c

Please sign in to comment.