Skip to content

Commit

Permalink
Merge branch 'warmlogic-mm-jax-poetry'
Browse files Browse the repository at this point in the history
  • Loading branch information
xeb committed Jun 24, 2022
2 parents 26f063d + f5ac079 commit 44e50d7
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 83 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ A (soon-to-be) collection of tools for generating [dalle-mini](https://github.co
Install the dependencies, then try out the CLI. Try `python generate.py --help` for more.

```sh
# Install poetry
curl -sSL https://install.python-poetry.org | python3 -
# If you installed poetry 1.1.x before, uninstall first
curl -sSL https://install.python-poetry.org | python3 - --uninstall

# Install poetry 1.2.x preview
curl -sSL https://install.python-poetry.org | python3 - --preview

# Create virtual env for this project, install requirements
poetry install
Expand Down
66 changes: 47 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["poetry>=1.1"]
requires = ["poetry-core"]
build-backend = "poetry.masonry.api"

[tool.poetry]
Expand All @@ -14,15 +14,17 @@ fire = "^0.4.0"
Flask = "^2.1.2"
ipywidgets = "^7.7.1"
pySqsListener = "^0.8.10"
python = ">=3.10.4,<3.11"
python = ">=3.10,<3.11"
python-slugify = "^6.1.2"
tokenizers = "~=0.11.6"
vqgan-jax = {git = "https://github.com/patil-suraj/vqgan-jax.git", rev = "main"}
slack-sdk = "^3.17.2"
slack-bolt = "^1.14.0"
tqdm = "^4.64.0"
jax = "^0.3.13"
jaxlib = "^0.3.10"
jaxlib = [
{url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.10+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl", markers = "platform_machine == 'linux'"},
{url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.3.10-cp310-none-macosx_11_0_arm64.whl", markers="platform_machine == 'arm64'"}
]

[tool.poetry.dev-dependencies]
black = "~22.3.0"
Expand Down
6 changes: 5 additions & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def output(path):
return redirect(f"/output/{path}")

return render_template(
"template.html", prompt=prompt, imgs=imgs, expected_img_count=expectedimgs, show_links=True
"template.html",
prompt=prompt,
imgs=imgs,
expected_img_count=expectedimgs,
show_links=True,
)

else:
Expand Down
11 changes: 8 additions & 3 deletions sitegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def get_dir_details(path):
with open(f"{path}/prompt.txt", "r") as rp:
prompt = rp.read()

imgs = [ os.path.basename(x) for x in glob.glob(f"{path}/[!f]*.png") ] #a small hack to ignore "final.png"
return ( prompt, imgs )
imgs = [
os.path.basename(x) for x in glob.glob(f"{path}/[!f]*.png")
] # a small hack to ignore "final.png"
return (prompt, imgs)


def generate_index(path, show_links=False):
tl = jinja2.FileSystemLoader(searchpath="./templates")
Expand All @@ -53,7 +56,9 @@ def generate_index(path, show_links=False):
if imgs is None or len(imgs) == 0:
return None

return template.render(prompt=prompt, imgs=imgs, expected_img_count=len(imgs), show_links=show_links)
return template.render(
prompt=prompt, imgs=imgs, expected_img_count=len(imgs), show_links=show_links
)


if __name__ == "__main__":
Expand Down
104 changes: 61 additions & 43 deletions slackbot.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,76 @@
#!/usr/bin/env python

import os
import time
from tqdm import tqdm
from request import send as send_queue_request

from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
from tqdm import tqdm

from request import send as send_queue_request

# import time


SLACK_BOT_TOKEN = os.environ['SLACK_BOT_TOKEN']
SLACK_APP_TOKEN = os.environ['SLACK_APP_TOKEN']
SLACK_BOT_TOKEN = os.environ["SLACK_BOT_TOKEN"]
SLACK_APP_TOKEN = os.environ["SLACK_APP_TOKEN"]

app = App(token=SLACK_BOT_TOKEN)


@app.event("app_mention")
def mention_handler(body, say, logger):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
prompt = event["text"].replace(event["text"].split(' ')[0].strip(), "").strip() # i feel dirty but its late
print(f"Generating {prompt=}")
# start = time.time()
rundir = send_queue_request(prompt)
say(text=f'On it!', thread_ts=thread_ts)

max_t = 8000000 # my 2080Ti can generate from SQS to final image in: 1672800 ticks
for i in tqdm(range(max_t)):
if i % 100000 == 0:
print(f"Checking {rundir} {i}/{max_t}")

if os.path.exists(f'output/{rundir}/final.png'):
img = f'https://dalle-mini-tools.xeb.ai/output/{rundir}/final.png'
print(f"Found {img}")
say(img)
# say(img, thread_ts=thread_ts)
# end = time.time()
# duration = (start - end)
# # if duration >= 86400:
# # days = int(duration / 86400)
# elapsed = time.strftime("%H hours, %M minutes, %S seconds", time.gmtime(duration))
# say(f'Took me {elapsed}.', thread_ts=thread_ts)
return

say(text=f'...I gave up.', thread_ts=thread_ts)
def mention_handler_app_mention(body, say, logger):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
prompt = (
event["text"].replace(event["text"].split(" ")[0].strip(), "").strip()
) # i feel dirty but its late
print(f"Generating {prompt=}")
# start = time.time()
rundir = send_queue_request(prompt)
say(text="On it!", thread_ts=thread_ts)

max_t = 8000000 # my 2080Ti can generate from SQS to final image in: 1672800 ticks
for i in tqdm(range(max_t)):
if i % 100000 == 0:
print(f"Checking {rundir} {i}/{max_t}")

if os.path.exists(f"output/{rundir}/final.png"):
img = f"https://dalle-mini-tools.xeb.ai/output/{rundir}/final.png"
print(f"Found {img}")
say(img)

# say(img, thread_ts=thread_ts)
# end = time.time()
# duration = start - end
# # if duration >= 86400:
# # days = int(duration / 86400)
# elapsed = time.strftime(
# "%H hours, %M minutes, %S seconds", time.gmtime(duration)
# )
# say(f"Took me {elapsed}.", thread_ts=thread_ts)
return

say(text="...I gave up.", thread_ts=thread_ts)


@app.event("message")
def mention_handler(body, say):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
if "text" not in event:
return
def mention_handler_message(body, say):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
if "text" not in event:
return

message = event["text"].strip()
if "generation station" in message.lower():
say(
text=(
"What was that? Did you want to generate an image? Just mention me"
" (@ImageGen) and tell me what you want."
),
thread_ts=thread_ts,
)

message = event["text"].strip()
if "generation station" in message.lower():
say(text='What was that? Did you want to generate an image? Just mention me (@ImageGen) and tell me what you want.', thread_ts=thread_ts)

if __name__ == "__main__":
handler = SocketModeHandler(app, SLACK_APP_TOKEN)
handler.start()
handler = SocketModeHandler(app, SLACK_APP_TOKEN)
handler.start()
22 changes: 11 additions & 11 deletions worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python

import os
import fire
import subprocess
from generate import Generator

import fire
from sqs_listener import SqsListener

from generate import Generator
Expand All @@ -13,20 +13,19 @@ class ImgGenListener(SqsListener):
def init_model(self, output_dir, clip_scores, postprocess):
self.generator = Generator(output_dir, clip_scores)
self.postprocess = postprocess
print(f"Initialized model")
print("Initialized model")

def postprocessing(self, run_name):
if not self.postprocess:
print(f"Postprocessing not enabled, skipping...")
print("Postprocessing not enabled, skipping...")

if os.path.exists("postprocess.sh"):
cmds = [ "./postprocess.sh", run_name ]
cmds = ["./postprocess.sh", run_name]
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
if p.returncode != 0:
print("-"**5)
print("-" ** 5)
print(f"Exception\n{err=}\n\n{out=}")


def handle_message(self, body, attr, msg_attr):
print(f"Processing {body=} {attr=} {msg_attr=}")
Expand All @@ -36,12 +35,13 @@ def handle_message(self, body, attr, msg_attr):
self.postprocessing(run_name)
print(f"Processed! {body=}")


def main(
output_dir="output",
clip_scores=False,
output_dir="output",
clip_scores=False,
postprocess=True,
queue_name="dalle-mini-tools",
error_queue="dalle-mini-tools_errors",
queue_name="dalle-mini-tools",
error_queue="dalle-mini-tools_errors",
region_name="us-east-1",
interval=1,
):
Expand Down

0 comments on commit 44e50d7

Please sign in to comment.