From 52b59c343443a2533c9724170f26bb653a817b46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 13 Dec 2023 01:28:43 +0800 Subject: [PATCH 1/2] fix load safetensors --- projects/powerpaint/gradio_PowerPaint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/powerpaint/gradio_PowerPaint.py b/projects/powerpaint/gradio_PowerPaint.py index ef2cd677a..8830d328a 100644 --- a/projects/powerpaint/gradio_PowerPaint.py +++ b/projects/powerpaint/gradio_PowerPaint.py @@ -11,6 +11,7 @@ StableDiffusionInpaintPipeline as Pipeline from pipeline.pipeline_PowerPaint_ControlNet import \ StableDiffusionControlNetInpaintPipeline as controlnetPipeline +from safetensors.torch import load_file from transformers import DPTFeatureExtractor, DPTForDepthEstimation from utils.utils import TokenizerWrapper, add_tokens @@ -34,7 +35,7 @@ initialize_tokens=['a', 'a', 'a'], num_vectors_per_token=10) pipe.unet.load_state_dict( - torch.load('./models/unet/diffusion_pytorch_model.bin'), strict=False) + load_file('./models/unet/diffusion_pytorch_model.safetensors', device='cuda'), strict=False) pipe.text_encoder.load_state_dict( torch.load('./models/text_encoder/pytorch_model.bin'), strict=False) pipe = pipe.to('cuda') From d3ba2c09e0f038583ae62964d523e9569cb1cc75 Mon Sep 17 00:00:00 2001 From: zengyh1900 Date: Fri, 15 Dec 2023 18:44:39 +0800 Subject: [PATCH 2/2] fix lint --- projects/powerpaint/gradio_PowerPaint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/projects/powerpaint/gradio_PowerPaint.py b/projects/powerpaint/gradio_PowerPaint.py index 8830d328a..11981560c 100644 --- a/projects/powerpaint/gradio_PowerPaint.py +++ b/projects/powerpaint/gradio_PowerPaint.py @@ -35,7 +35,9 @@ initialize_tokens=['a', 'a', 'a'], num_vectors_per_token=10) pipe.unet.load_state_dict( - load_file('./models/unet/diffusion_pytorch_model.safetensors', device='cuda'), strict=False) + load_file( + './models/unet/diffusion_pytorch_model.safetensors', device='cuda'), + strict=False) pipe.text_encoder.load_state_dict( torch.load('./models/text_encoder/pytorch_model.bin'), strict=False) pipe = pipe.to('cuda')