diff --git a/bmf/demo/controlnet/ReadMe.md b/bmf/demo/controlnet/ReadMe.md
new file mode 100644
index 00000000..9c5efa14
--- /dev/null
+++ b/bmf/demo/controlnet/ReadMe.md
@@ -0,0 +1,31 @@
+# BMF ControlNet Demo
+
+This demo shows how to use ControlNet + Stable Diffusion to generate images from text prompts in BMF. We use a performance-optimized ControlNet [implementation](https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/Hackathon2023/controlnet) that accelerates the canny2image app from the official ControlNet repo.
+
+You need to build or install BMF before running the demo. Please refer to the [documentation](https://babitmf.github.io/docs/bmf/getting_started_yourself/install/) on how to build or install BMF.
+
+### Generate the TensorRT Engines
+
+First, we need to put the ControlNet code in the demo directory. The repo below contains many TensorRT samples; the ControlNet implementation we need is located in `trt-samples-for-hackathon-cn/Hackathon2023/controlnet`.
+```Bash
+git clone https://github.com/NVIDIA/trt-samples-for-hackathon-cn.git
+# copy the ControlNet implementation to the demo path for simplicity
+cp -r trt-samples-for-hackathon-cn/Hackathon2023/controlnet bmf/demo/controlnet
+```
+
+Download the state dict from HuggingFace and generate the TensorRT engines. You need to change the state-dict path in `controlnet/export_onnx.py:19` to where you put the file, then run `preprocess.sh` to build the engines.
+```Bash
+cd bmf/demo/controlnet/controlnet/models
+wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_canny.pth
+# change the path to './models/control_sd15_canny.pth' in controlnet/export_onnx.py:19
+cd ..  # go back to the controlnet directory
+bash preprocess.sh
+```
+
+Once the script runs successfully, several `.trt` files will be generated; these are the TensorRT engines. Copy them to the demo directory and run the ControlNet pipeline using the `test_controlnet.py` script.
+```Bash
+mv *.trt path/to/the/demo
+cd path/to/the/demo
+python test_controlnet.py
+```
+The pipeline will generate a new image based on the input image and prompt.
\ No newline at end of file
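As a quick sanity check that engine generation succeeded, each `.trt` file should deserialize cleanly with the TensorRT runtime. A minimal sketch (the engine filename below is a placeholder, not a name produced by the demo; use whichever `.trt` files `preprocess.sh` actually generated):

```python
import tensorrt as trt

# Deserialize a generated engine; a None result means the file is corrupt
# or was built with an incompatible TensorRT version.
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("example_engine.trt", "rb") as f:  # placeholder filename
    engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None, "failed to deserialize TensorRT engine"
```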
diff --git a/bmf/demo/controlnet/controlnet_module.py b/bmf/demo/controlnet/controlnet_module.py
new file mode 100644
index 00000000..2f4e65d6
--- /dev/null
+++ b/bmf/demo/controlnet/controlnet_module.py
@@ -0,0 +1,84 @@
+import sys
+
+from bmf import *
+import bmf.hml.hmp as mp
+
+sys.path.append('./controlnet')
+from canny2image_TRT import hackathon
+
+
+class controlnet_module(Module):
+
+    def __init__(self, node, option=None):
+        self.node_ = node
+        self.hk = hackathon()
+        self.hk.initialize()
+        self.prompt_path = './prompt.txt'
+        # per-input EOF flags: index 0 is the prompt stream, index 1 the image stream
+        self.eof_received_ = [False, False]
+        self.prompt_ = None
+        self.frame_list_ = []
+        if option and 'path' in option.keys():
+            self.prompt_path = option['path']
+
+    def process(self, task):
+        img_queue = task.get_inputs()[0]
+        pmt_queue = task.get_inputs()[1]
+        output_queue = task.get_outputs()[0]
+
+        # drain the prompt stream, keeping the latest prompt dict
+        while not pmt_queue.empty():
+            pmt_pkt = pmt_queue.get()
+
+            if pmt_pkt.timestamp == Timestamp.EOF:
+                self.eof_received_[0] = True
+            else:
+                self.prompt_ = pmt_pkt.get(dict)
+
+        # drain the image stream, buffering frames until a prompt arrives
+        while not img_queue.empty():
+            in_pkt = img_queue.get()
+
+            if in_pkt.timestamp == Timestamp.EOF:
+                self.eof_received_[1] = True
+            else:
+                self.frame_list_.append(in_pkt.get(VideoFrame))
+
+        while self.prompt_ and len(self.frame_list_) > 0:
+            in_frame = self.frame_list_.pop(0)
+
+            # positional arguments follow the canny2image app: num_samples,
+            # image_resolution, ddim_steps, guess_mode, strength, scale,
+            # seed, eta, canny low_threshold, canny high_threshold
+            gen_img = self.hk.process(
+                in_frame.cpu().frame().data()[0].numpy(),
+                self.prompt_['prompt'], self.prompt_['a_prompt'],
+                self.prompt_['n_prompt'],
+                1, 256, 20, False, 1, 9, 2946901, 0.0, 100, 200)
+
+            # wrap the generated RGB image back into a VideoFrame, keeping
+            # the input frame's color space, range, and timing
+            rgbinfo = mp.PixelInfo(mp.PixelFormat.kPF_RGB24,
+                                   in_frame.frame().pix_info().space,
+                                   in_frame.frame().pix_info().range)
+            out_f = mp.Frame(mp.from_numpy(gen_img[0]), rgbinfo)
+            out_vf = VideoFrame(out_f)
+            out_vf.pts = in_frame.pts
+            out_vf.time_base = in_frame.time_base
+            out_pkt = Packet(out_vf)
+            out_pkt.timestamp = out_vf.pts
+            output_queue.put(out_pkt)
+
+        if self.eof_received_[0] and self.eof_received_[1] and len(self.frame_list_) == 0:
+            output_queue.put(Packet.generate_eof_packet())
+            Log.log_node(LogLevel.DEBUG, self.node_, 'output video stream', 'done')
+            task.set_timestamp(Timestamp.DONE)
+            return ProcessResult.OK
+
+        return ProcessResult.OK
+
+
+def register_controlnet_module_info(info):
+    info.module_description = "ControlNet inference module"
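The module above expects a dict packet on its second input carrying the prompt keys. A minimal sketch of building one by hand, mirroring what `text_module` emits (keys taken from `prompt.txt`, which follows):

```python
from bmf import Packet

# Hand-built prompt packet; the keys must match those read in
# controlnet_module.process.
prompt = {
    'prompt': 'a bird',
    'a_prompt': 'best quality, extremely detailed',
    'n_prompt': 'longbody, lowres, bad anatomy, bad hands, missing fingers',
}
pkt = Packet(prompt)
pkt.timestamp = 0  # any non-EOF timestamp
```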
diff --git a/bmf/demo/controlnet/prompt.txt b/bmf/demo/controlnet/prompt.txt
new file mode 100644
index 00000000..401a466d
--- /dev/null
+++ b/bmf/demo/controlnet/prompt.txt
@@ -0,0 +1,3 @@
+prompt: a bird
+a_prompt: best quality, extremely detailed
+n_prompt: longbody, lowres, bad anatomy, bad hands, missing fingers
\ No newline at end of file
diff --git a/bmf/demo/controlnet/test_controlnet.py b/bmf/demo/controlnet/test_controlnet.py
new file mode 100644
index 00000000..18497cef
--- /dev/null
+++ b/bmf/demo/controlnet/test_controlnet.py
@@ -0,0 +1,60 @@
+import sys
+
+sys.path.append("../../")
+import bmf
+
+sys.path.pop()
+
+
+def test():
+    input_video_path = "./controlnet/test_imgs/bird.png"
+    input_prompt_path = "./prompt.txt"
+    output_path = "./output.png"
+
+    graph = bmf.graph()
+
+    # dual inputs: the decoded image and the prompt text are fed to
+    # controlnet_module as two separate streams
+    # -------------------------------------------------------------------------
+    video = graph.decode({'input_path': input_video_path})
+    prompt = video.module('text_module', {'path': input_prompt_path})
+
+    control = bmf.module(streams=[video, prompt], module_info='controlnet_module')
+    control.encode(None, {'output_path': output_path}).run()
+
+    # sync mode
+    # from bmf import bmf_sync, Packet
+    # decoder = bmf_sync.sync_module("c_ffmpeg_decoder", {"input_path": "./controlnet/test_imgs/bird.png"}, [], [0])
+    # prompt = bmf_sync.sync_module('text_module', {'path': './prompt.txt'}, [], [1])
+    # controlnet = bmf_sync.sync_module('controlnet_module', {}, [0, 1], [0])
+
+    # decoder.init()
+    # prompt.init()
+    # controlnet.init()
+
+    # img, _ = bmf_sync.process(decoder, None)
+    # txt, _ = bmf_sync.process(prompt, None)
+    # gen_img, _ = bmf_sync.process(controlnet, {0: img[0], 1: txt[1]})
+    # -------------------------------------------------------------------------
+
+    # hardware-accelerated decode variant (uncomment to try)
+    # video = graph.decode({
+    #     "input_path": input_video_path,
+    #     # "video_params": {
+    #     #     "hwaccel": "cuda",
+    #     #     # "pix_fmt": "yuv420p",
+    #     # }
+    # })
+    # (video['video']
+    #  .module('controlnet_module', {})
+    #  .encode(
+    #      None, {
+    #          "output_path": output_path,
+    #          "video_params": {
+    #              "codec": "png",
+    #              # "pix_fmt": "cuda",
+    #          }
+    #      }).run())
+
+
+if __name__ == '__main__':
+    test()
diff --git a/bmf/demo/controlnet/text_module.py b/bmf/demo/controlnet/text_module.py
new file mode 100644
index 00000000..511a18e5
--- /dev/null
+++ b/bmf/demo/controlnet/text_module.py
@@ -0,0 +1,43 @@
+from bmf import *
+
+
+class text_module(Module):
+
+    def __init__(self, node, option=None):
+        self.node_ = node
+        self.eof_received_ = False
+        self.prompt_path = './prompt.txt'
+        if option and 'path' in option.keys():
+            self.prompt_path = option['path']
+
+    def process(self, task):
+        input_packets = task.get_inputs()[0]
+        output_queue = task.get_outputs()[0]
+
+        while not input_packets.empty():
+            pkt = input_packets.get()
+            if pkt.timestamp == Timestamp.EOF:
+                output_queue.put(Packet.generate_eof_packet())
+                Log.log_node(LogLevel.DEBUG, self.node_, 'output text stream', 'done')
+                task.set_timestamp(Timestamp.DONE)
+                return ProcessResult.OK
+
+            # parse 'key: value' lines into a prompt dict, stripping
+            # surrounding whitespace from both keys and values
+            prompt_dict = dict()
+            with open(self.prompt_path) as f:
+                for line in f:
+                    pk, _, pt = line.partition(":")
+                    prompt_dict[pk.strip()] = pt.strip()
+
+            out_pkt = Packet(prompt_dict)
+            out_pkt.timestamp = 0
+            output_queue.put(out_pkt)
+
+        return ProcessResult.OK
+
+
+def register_text_module_info(info):
+    info.module_description = "Text file I/O module"
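For debugging, each module can also be driven on its own with BMF's sync mode, following the commented-out block in `test_controlnet.py`. A sketch for `text_module` (the stream indices and the shape of the returned packet map mirror that commented usage, not a verified API contract):

```python
from bmf import bmf_sync

# Run text_module standalone and inspect the prompt packet it emits.
prompt = bmf_sync.sync_module('text_module', {'path': './prompt.txt'}, [], [1])
prompt.init()
txt, _ = bmf_sync.process(prompt, None)
# txt[1] holds the packets produced on output stream 1,
# as in the sync-mode block of test_controlnet.py
```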