Skip to content

Commit

Permalink
grep bin script. May delete later
Browse files Browse the repository at this point in the history
  • Loading branch information
Hudson Cooper committed Feb 23, 2024
1 parent cae7532 commit 67f18cf
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/minml/bin/grep3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python

import sys
import argparse
import json
from pydantic import RootModel, create_model
from pydantic.json_schema import to_jsonable_python
from pydantic.fields import PydanticUndefined

from gpts import Mistral
from guidance.models import LlamaCppChat
from minml.prompt.grep import GrepPrompt


def dumps(obj):
return json.dumps(obj, default=to_jsonable_python)


def get_model():
from minml.util import suppress_stdout_stderr

with suppress_stdout_stderr():
model = Mistral(verbose=False)
return LlamaCppChat(model.llm, echo=False)


def parse_args(argv):
if argv is None:
argv = sys.argv[1:]

parser = argparse.ArgumentParser()
parser.add_argument("irregex")
parser.add_argument("--schema", type=str)
parser.add_argument("file", nargs="?")
args = parser.parse_args(argv)

if args.file is None:
file = sys.stdin
else:
file = open(args.file)
text = file.read()

if args.schema:
schema = {}
for item in args.schema.split(","):
item = item.replace(" ", "")
name, *ta = item.split("=")
assert len(ta) <= 1
if not ta:
ta = str
else:
ta = eval(ta[0])
schema[name] = (ta, PydanticUndefined)
Schema = create_model(args.irregex.capitalize(), **schema)
else:
Schema = create_model(
args.irregex.capitalize(), root=(str, PydanticUndefined), __base__=RootModel
)
return argparse.Namespace(text=text, irregex=args.irregex, schema=Schema)


def main(argv=None):
args = parse_args(argv)

model = get_model()
prompt = GrepPrompt(model)
response = prompt(text=args.text, object_schema=args.schema, name=args.irregex)
# print("\n".join(dumps(c) for c in response.object))
print(response.prompt_with_completion)


if __name__ == "__main__":
main(sys.argv[1:])
36 changes: 36 additions & 0 deletions src/minml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,39 @@ def resolve_refs(schema, defs=None):
v = resolve_refs(v, defs)
new_schema[k] = v
return new_schema


# I hate LlamaCpp
class suppress_stdout_stderr(object):
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")

self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()

self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())

self.old_stdout = sys.stdout
self.old_stderr = sys.stderr

os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self

def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr

os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)

self.outnull_file.close()
self.errnull_file.close()

0 comments on commit 67f18cf

Please sign in to comment.