test_diffusion_engine2.py
import os
import datetime

import numpy as np
import tensorrt as trt
import torch
from cuda import cudart

os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
H = 256
W = 384
control_trt = "./models/enginemodels/sd_control_fp16.engine"
diffusion_trt = "./models/enginemodels/sd_diffusion_fp16.engine"
class hackathon():
    def initialize(self):
        self.logger = trt.Logger(trt.Logger.VERBOSE)
        trt.init_libnvinfer_plugins(self.logger, '')
        #################### load the ControlNet engine
        with open(control_trt, "rb") as f:
            engineString = f.read()
        # deserialize the inference engine
        self.control_engine = trt.Runtime(self.logger).deserialize_cuda_engine(engineString)
        self.control_context = self.control_engine.create_execution_context()
        ################# load the diffusion (UNet) engine
        with open(diffusion_trt, "rb") as f:
            engineString = f.read()
        # deserialize the inference engine
        self.diffusion_engine = trt.Runtime(self.logger).deserialize_cuda_engine(engineString)
        self.diffusion_context = self.diffusion_engine.create_execution_context()
        ############################ create the input/output buffers
        start = datetime.datetime.now().timestamp()
        control_nIO = self.control_engine.num_io_tensors
        control_tensor_name = [self.control_engine.get_tensor_name(i) for i in range(control_nIO)]
        diffusion_nIO = self.diffusion_engine.num_io_tensors
        diffusion_tensor_name = [self.diffusion_engine.get_tensor_name(i) for i in range(diffusion_nIO)]
        self.buffer_device = []
        b, c, h, w = 1, 4, H // 8, W // 8
        # pre-allocate the 13 ControlNet output tensors and record their device pointers
        self.control_out = []
        for i in range(3):
            temp = torch.zeros(b, 320, h, w, dtype=torch.float32).to("cuda")
            self.control_out.append(temp)
            self.buffer_device.append(temp.reshape(-1).data_ptr())
        temp = torch.zeros(b, 320, h // 2, w // 2, dtype=torch.float32).to("cuda")
        self.control_out.append(temp)
        self.buffer_device.append(temp.reshape(-1).data_ptr())
        for i in range(2):
            temp = torch.zeros(b, 640, h // 2, w // 2, dtype=torch.float32).to("cuda")
            self.control_out.append(temp)
            self.buffer_device.append(temp.reshape(-1).data_ptr())
        temp = torch.zeros(b, 640, h // 4, w // 4, dtype=torch.float32).to("cuda")
        self.control_out.append(temp)
        self.buffer_device.append(temp.reshape(-1).data_ptr())
        for i in range(2):
            temp = torch.zeros(b, 1280, h // 4, w // 4, dtype=torch.float32).to("cuda")
            self.control_out.append(temp)
            self.buffer_device.append(temp.reshape(-1).data_ptr())
        for i in range(4):
            temp = torch.zeros(b, 1280, h // 8, w // 8, dtype=torch.float32).to("cuda")
            self.control_out.append(temp)
            self.buffer_device.append(temp.reshape(-1).data_ptr())
        # output buffer for the predicted noise
        self.eps = torch.zeros(1, 4, 32, 48, dtype=torch.float32).to("cuda")
        self.buffer_device.append(self.eps.reshape(-1).data_ptr())
        end = datetime.datetime.now().timestamp()
        print("\nTime saved by initialize (ms):", (end - start) * 1000)
    def process(self):
        ################################### ControlNet inputs
        x_noisy = torch.randn(1, 4, H // 8, W // 8, dtype=torch.float32).to("cuda")
        hint_in = torch.randn(1, 3, H, W, dtype=torch.float32).to("cuda")
        t = torch.zeros(1, dtype=torch.int64).to("cuda")
        cond_txt = torch.randn(1, 77, 768, dtype=torch.float32).to("cuda")
        self.buffer_device.append(x_noisy.reshape(-1).data_ptr())
        self.buffer_device.append(hint_in.reshape(-1).data_ptr())
        self.buffer_device.append(t.reshape(-1).data_ptr())
        self.buffer_device.append(cond_txt.reshape(-1).data_ptr())
        # execute the ControlNet engine
        self.control_context.execute_v2(self.buffer_device)
        ################################### diffusion (UNet) inputs
        self.x_in = torch.randn(1, 4, H // 8, W // 8, dtype=torch.float32).to("cuda")
        self.time_in = torch.zeros(1, dtype=torch.int64).to("cuda")
        self.context_in = torch.randn(1, 77, 768, dtype=torch.float32).to("cuda")
        # the 13 control feature maps fed to the UNet
        self.control = []
        self.control.append(torch.randn(1, 320, H // 8, W // 8, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 320, H // 8, W // 8, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 320, H // 8, W // 8, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 320, H // 16, W // 16, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 640, H // 16, W // 16, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 640, H // 16, W // 16, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 640, H // 32, W // 32, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 32, W // 32, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 32, W // 32, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 64, W // 64, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 64, W // 64, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 64, W // 64, dtype=torch.float32).to("cuda"))
        self.control.append(torch.randn(1, 1280, H // 64, W // 64, dtype=torch.float32).to("cuda"))
        self.buffer_device.append(self.x_in.reshape(-1).data_ptr())
        self.buffer_device.append(self.context_in.reshape(-1).data_ptr())
        self.buffer_device.append(self.time_in.reshape(-1).data_ptr())
        for temp in self.control:
            self.buffer_device.append(temp.reshape(-1).data_ptr())
        # execute the diffusion engine
        self.diffusion_context.execute_v2(self.buffer_device)
        return self.eps
if __name__ == "__main__":
    times = []
    h = hackathon()
    h.initialize()
    start = datetime.datetime.now().timestamp()
    h.process()
    end = datetime.datetime.now().timestamp()
    times.append((end - start) * 1000)
    # print("\ncontrolnet output:", h.control_out)
    print("\ndiffusion output:", h.eps)
    print("\nTime to run process (ms):", times)