-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reorg OSS Diffusion Components to diffusion_labs folder (#480)
Summary: Based on this [proposal](https://docs.google.com/document/d/1GtN2urD8PiRr1X4COzvbVbNE8LWRrcoYkAO4v2aogO8/edit) to reorganize diffusion components and models under a new `diffusion_labs`. This is the first in a stack of diffs. This one only reorganizes what's already been moved to OSS. This is primarily moving files with a couple of changes based on the proposal: - predictors.py is split into a separate file per predictor - adm is moved out of dalle2 to be it's own model adm_unet - Dalle2ImageTransform is moved to dalle2 out of transforms - schedule.py is renamed to discrete_guassian_schedule.py and an abstract DIffusionSchedule class was added - An abstract adapter class was added to be a generic type and enforce the `forward` signature - An abstract sampler class was added to be a generic type and enforce the `forward` and `generator` signature - A new dalle2_model unit test was added Differential Revision: D49790849 Pulled By: pbontrager fbshipit-source-id: 98fe40c2418dc542cced940dc761a9cd602b0398
- Loading branch information
1 parent
f2cfe1a
commit b8226b9
Showing
45 changed files
with
654 additions
and
348 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#!/usr/bin/env fbpython | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from PIL import Image | ||
from tests.test_utils import assert_expected, set_rng_seed | ||
from torchmultimodal.diffusion_labs.models.dalle2.dalle2_decoder import dalle2_decoder | ||
from torchmultimodal.diffusion_labs.models.dalle2.transforms import Dalle2ImageTransform | ||
|
||
|
||
def test_dalle2_model(): | ||
set_rng_seed(4) | ||
model = dalle2_decoder( | ||
timesteps=1, | ||
time_embed_dim=1, | ||
cond_embed_dim=1, | ||
clip_embed_dim=1, | ||
clip_embed_name="clip_image", | ||
predict_variance_value=True, | ||
image_channels=1, | ||
depth=32, | ||
num_resize=1, | ||
num_res_per_layer=1, | ||
use_cf_guidance=True, | ||
clip_image_guidance_dropout=0.1, | ||
guidance_strength=7.0, | ||
learn_null_emb=True, | ||
) | ||
model.eval() | ||
x = torch.randn(1, 1, 4, 4) | ||
c = torch.ones((1, 1)) | ||
with torch.no_grad(): | ||
actual = model(x, conditional_inputs={"clip_image": c}).mean() | ||
expected = torch.as_tensor(0.12768) | ||
assert_expected(actual, expected, rtol=0, atol=1e-4) | ||
|
||
|
||
def test_dalle2_image_transform(): | ||
img_size = 5 | ||
transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1) | ||
image = Image.new("RGB", size=(20, 20), color=(128, 0, 0)) | ||
actual = transform(image).sum() | ||
normalized128 = 128 / 255 * 2 - 1 | ||
normalized0 = -1 | ||
expected = torch.tensor( | ||
normalized128 * img_size**2 + 2 * normalized0 * img_size**2 | ||
) | ||
assert_expected(actual, expected, rtol=0, atol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. |
Oops, something went wrong.