Skip to content

Commit

Permalink
[Dit] Fix dit tests (open-mmlab#2034)
Browse files Browse the repository at this point in the history
* [Dit] Fix dit tests

* up
  • Loading branch information
patrickvonplaten authored Jan 19, 2023
1 parent ed616bd commit 013955b
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tests/pipelines/dit/test_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
sample_size=4,
sample_size=16,
num_layers=2,
patch_size=2,
attention_head_dim=2,
patch_size=4,
attention_head_dim=8,
num_attention_heads=2,
in_channels=4,
out_channels=8,
Expand Down Expand Up @@ -79,10 +79,8 @@ def test_inference(self):
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]

self.assertEqual(image.shape, (1, 4, 4, 3))
expected_slice = np.array(
[0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
)
self.assertEqual(image.shape, (1, 16, 16, 3))
expected_slice = np.array([0.4380, 0.4141, 0.5159, 0.0000, 0.4282, 0.6680, 0.5485, 0.2545, 0.6719])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)

Expand Down

0 comments on commit 013955b

Please sign in to comment.