-
Notifications
You must be signed in to change notification settings - Fork 8
/
prompts.py
40 lines (31 loc) · 1.03 KB
/
prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from importlib import resources
import os
import functools
import random
import inflect
# Module-level inflect engine. Nothing in this file references it.
# NOTE(review): presumably imported by other modules (e.g. for pluralizing
# prompt words) — confirm before removing.
IE = inflect.engine()
# Traversable root of the `assets` package, resolved at import time (so an
# importable `assets` package is required). `_load_lines` falls back to files
# under this root when a path is not found on the filesystem.
ASSETS_PATH = resources.files("assets")
@functools.cache
def _load_lines(path):
    """
    Load the lines of a text file, with surrounding whitespace stripped from each line.

    `path` is first tried as a regular file on the filesystem. If no such file exists, it is
    looked up relative to the `assets` package (`ASSETS_PATH`). Results are memoized per `path`
    by `functools.cache`, so every caller shares the same list and must not mutate it.

    Args:
        path: A filesystem path, or a path relative to the `assets` package.

    Returns:
        A list of stripped lines in file order. Blank lines are kept as empty strings, so line
        indices stay aligned with the file.

    Raises:
        FileNotFoundError: If `path` is neither a file on disk nor a file inside `assets`.
    """
    if os.path.isfile(path):
        handle = open(path, "r", encoding="utf-8")
    else:
        # Use the Traversable API instead of os.path/open(). ASSETS_PATH is not guaranteed
        # to be os.PathLike (e.g. when the `assets` package is imported from a zip).
        asset = ASSETS_PATH.joinpath(path)
        if not asset.is_file():
            raise FileNotFoundError(f"Could not find {path} or assets/{path}")
        handle = asset.open("r", encoding="utf-8")
    # Decode explicitly as UTF-8 so the result does not depend on the platform locale.
    with handle as f:
        return [line.strip() for line in f]
def from_file(path, low=None, high=None):
    """
    Pick one line uniformly at random from a prompt file.

    Args:
        path: File to read. It is resolved by `_load_lines`: the filesystem first, then the
            `assets` package.
        low: Optional start index of the candidate window, with list-slice semantics.
        high: Optional end index (exclusive) of the candidate window, with list-slice semantics.

    Returns:
        A `(prompt, {})` tuple: the chosen line and a fresh empty dict.
        NOTE(review): the dict is presumably a metadata slot that callers expect — confirm.

    Raises:
        IndexError: If the `[low:high]` window of the file contains no lines.
    """
    lines = _load_lines(path)
    # Pick an index from the sliced range instead of slicing the cached list, so no per-call
    # copy is made. In CPython, random.choice's draw depends only on len(candidates), so a
    # seeded run picks the same prompt as slicing the list would.
    candidates = range(len(lines))[low:high]
    if not candidates:
        raise IndexError(f"No prompts to choose from in {path} (slice [{low}:{high}])")
    return lines[random.choice(candidates)], {}
def hps_v2_all():
    """
    Sample a random prompt from ``hps_v2_all.txt``.

    Returns whatever `from_file` returns for that file.
    """
    prompt_file = "hps_v2_all.txt"
    return from_file(path=prompt_file)
def simple_animals():
    """
    Sample a random prompt from ``simple_animals.txt``.

    Returns whatever `from_file` returns for that file.
    """
    prompt_file = "simple_animals.txt"
    return from_file(path=prompt_file)
def eval_simple_animals():
    """
    Sample a random prompt from ``eval_simple_animals.txt``.

    Returns whatever `from_file` returns for that file.
    """
    prompt_file = "eval_simple_animals.txt"
    return from_file(path=prompt_file)
def eval_hps_v2_all():
    """
    Sample a random prompt from ``hps_v2_all_eval.txt``.

    Returns whatever `from_file` returns for that file.
    """
    prompt_file = "hps_v2_all_eval.txt"
    return from_file(path=prompt_file)