|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 |
|
17 |
| -import inspect |
18 |
| -import tempfile |
19 | 17 | import unittest
|
20 | 18 |
|
21 | 19 | import numpy as np
|
22 | 20 | import torch
|
23 | 21 |
|
24 | 22 | from diffusers.models.embeddings import get_timestep_embedding
|
25 |
| -from diffusers.models.resnet import Downsample1D, Downsample2D, Upsample1D, Upsample2D |
26 |
| -from diffusers.testing_utils import floats_tensor, slow, torch_device |
| 23 | +from diffusers.models.resnet import Downsample2D, Upsample2D |
27 | 24 |
|
28 | 25 |
|
29 | 26 | torch.backends.cuda.matmul.allow_tf32 = False
|
@@ -219,108 +216,3 @@ def test_downsample_with_conv_out_dim(self):
|
219 | 216 | output_slice = downsampled[0, -1, -3:, -3:]
|
220 | 217 | expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
|
221 | 218 | assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
222 |
| - |
223 |
| - |
224 |
| -class Upsample1DBlockTests(unittest.TestCase): |
225 |
| - def test_upsample_default(self): |
226 |
| - torch.manual_seed(0) |
227 |
| - sample = torch.randn(1, 32, 32) |
228 |
| - upsample = Upsample1D(channels=32, use_conv=False) |
229 |
| - with torch.no_grad(): |
230 |
| - upsampled = upsample(sample) |
231 |
| - |
232 |
| - assert upsampled.shape == (1, 32, 64) |
233 |
| - output_slice = upsampled[0, -1, -8:] |
234 |
| - expected_slice = torch.tensor([-1.6340, -1.6340, 0.5374, 0.5374, 1.0826, 1.0826, -1.7105, -1.7105]) |
235 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
236 |
| - |
237 |
| - def test_upsample_with_conv(self): |
238 |
| - torch.manual_seed(0) |
239 |
| - sample = torch.randn(1, 32, 32) |
240 |
| - upsample = Upsample1D(channels=32, use_conv=True) |
241 |
| - with torch.no_grad(): |
242 |
| - upsampled = upsample(sample) |
243 |
| - |
244 |
| - assert upsampled.shape == (1, 32, 64) |
245 |
| - output_slice = upsampled[0, -1, -8:] |
246 |
| - expected_slice = torch.tensor([-0.4546, -0.5010, -0.2996, 0.2844, 0.4040, -0.7772, -0.6862, 0.3612]) |
247 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
248 |
| - |
249 |
| - def test_upsample_with_conv_out_dim(self): |
250 |
| - torch.manual_seed(0) |
251 |
| - sample = torch.randn(1, 32, 32) |
252 |
| - upsample = Upsample1D(channels=32, use_conv=True, out_channels=64) |
253 |
| - with torch.no_grad(): |
254 |
| - upsampled = upsample(sample) |
255 |
| - |
256 |
| - assert upsampled.shape == (1, 64, 64) |
257 |
| - output_slice = upsampled[0, -1, -8:] |
258 |
| - expected_slice = torch.tensor([-0.0516, -0.0972, 0.9740, 1.1883, 0.4539, -0.5285, -0.5851, 0.1152]) |
259 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
260 |
| - |
261 |
| - def test_upsample_with_transpose(self): |
262 |
| - torch.manual_seed(0) |
263 |
| - sample = torch.randn(1, 32, 32) |
264 |
| - upsample = Upsample1D(channels=32, use_conv=False, use_conv_transpose=True) |
265 |
| - with torch.no_grad(): |
266 |
| - upsampled = upsample(sample) |
267 |
| - |
268 |
| - assert upsampled.shape == (1, 32, 64) |
269 |
| - output_slice = upsampled[0, -1, -8:] |
270 |
| - expected_slice = torch.tensor([-0.2238, -0.5842, -0.7165, 0.6699, 0.1033, -0.4269, -0.8974, -0.3716]) |
271 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
272 |
| - |
273 |
| - |
274 |
| -class Downsample1DBlockTests(unittest.TestCase): |
275 |
| - def test_downsample_default(self): |
276 |
| - torch.manual_seed(0) |
277 |
| - sample = torch.randn(1, 32, 64) |
278 |
| - downsample = Downsample1D(channels=32, use_conv=False) |
279 |
| - with torch.no_grad(): |
280 |
| - downsampled = downsample(sample) |
281 |
| - |
282 |
| - assert downsampled.shape == (1, 32, 32) |
283 |
| - output_slice = downsampled[0, -1, -8:] |
284 |
| - expected_slice = torch.tensor([-0.8796, 1.0945, -0.3434, 0.2910, 0.3391, -0.4488, -0.9568, -0.2909]) |
285 |
| - max_diff = (output_slice.flatten() - expected_slice).abs().sum().item() |
286 |
| - assert max_diff <= 1e-3 |
287 |
| - # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1) |
288 |
| - |
289 |
| - def test_downsample_with_conv(self): |
290 |
| - torch.manual_seed(0) |
291 |
| - sample = torch.randn(1, 32, 64) |
292 |
| - downsample = Downsample1D(channels=32, use_conv=True) |
293 |
| - with torch.no_grad(): |
294 |
| - downsampled = downsample(sample) |
295 |
| - |
296 |
| - assert downsampled.shape == (1, 32, 32) |
297 |
| - output_slice = downsampled[0, -1, -8:] |
298 |
| - |
299 |
| - expected_slice = torch.tensor( |
300 |
| - [0.1723, 0.0811, -0.6205, -0.3045, 0.0666, -0.2381, -0.0238, 0.2834], |
301 |
| - ) |
302 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
303 |
| - |
304 |
| - def test_downsample_with_conv_pad1(self): |
305 |
| - torch.manual_seed(0) |
306 |
| - sample = torch.randn(1, 32, 64) |
307 |
| - downsample = Downsample1D(channels=32, use_conv=True, padding=1) |
308 |
| - with torch.no_grad(): |
309 |
| - downsampled = downsample(sample) |
310 |
| - |
311 |
| - assert downsampled.shape == (1, 32, 32) |
312 |
| - output_slice = downsampled[0, -1, -8:] |
313 |
| - expected_slice = torch.tensor([0.1723, 0.0811, -0.6205, -0.3045, 0.0666, -0.2381, -0.0238, 0.2834]) |
314 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
315 |
| - |
316 |
| - def test_downsample_with_conv_out_dim(self): |
317 |
| - torch.manual_seed(0) |
318 |
| - sample = torch.randn(1, 32, 64) |
319 |
| - downsample = Downsample1D(channels=32, use_conv=True, out_channels=16) |
320 |
| - with torch.no_grad(): |
321 |
| - downsampled = downsample(sample) |
322 |
| - |
323 |
| - assert downsampled.shape == (1, 16, 32) |
324 |
| - output_slice = downsampled[0, -1, -8:] |
325 |
| - expected_slice = torch.tensor([1.1067, -0.5255, -0.4451, 0.0487, -0.3664, -0.7945, -0.4495, -0.3129]) |
326 |
| - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
0 commit comments